jax.experimental.pallas.atomic_cas

jax.experimental.pallas.atomic_cas(ref, cmp, val)

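Performs an atomic compare-and-swap on the value stored in ref: if the current value equals cmp, it is replaced with val; otherwise ref is left unchanged. The read, comparison, and conditional write happen as a single indivisible operation.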
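As a rough illustration, the sequential Python sketch below models what one compare-and-swap call does. `Cell` and `cas_model` are hypothetical names introduced only for this sketch: the model is not atomic and is not the Pallas implementation, which performs the whole read-compare-write step atomically on the device.

```python
from dataclasses import dataclass


@dataclass
class Cell:
    """Hypothetical stand-in for a ref that holds a single value."""
    value: int


def cas_model(ref: Cell, cmp: int, val: int) -> int:
    """Sequential model of compare-and-swap semantics.

    Not atomic and not the Pallas API: it only shows what a single
    compare-and-swap does when no other writer interferes.
    """
    old = ref.value          # read the current value
    if old == cmp:           # compare against the expected value
        ref.value = val      # write only on a match
    return old               # always return the prior value


cell = Cell(0)
print(cas_model(cell, 0, 5), cell.value)  # 0 5: matched, so 5 was written
print(cas_model(cell, 0, 7), cell.value)  # 5 5: mismatch, nothing written
```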

Parameters:
  • ref – The ref to operate on.

  • cmp – The expected value. ref is updated only if its current value equals cmp.

  • val – The value to write into ref when the comparison succeeds.

Returns:

The value stored in ref immediately before the atomic operation, whether or not the swap took place. The swap succeeded exactly when this value equals cmp.
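Because the returned value is the prior contents of ref, a caller can tell whether its write took effect by comparing that value with cmp. The usual pattern built on this is a retry loop that recomputes the new value from whatever it last observed. The sketch below shows that loop in plain, single-threaded Python; `Cell`, `cas_model`, and `fetch_max_via_cas` are hypothetical names, and `cas_model` is only a non-atomic stand-in for a real compare-and-swap. With a single writer the retry branch never fires here. Inside a traced Pallas kernel a Python `while` cannot branch on traced values, so such a loop would presumably need a JAX control-flow primitive such as `jax.lax.while_loop`.

```python
from dataclasses import dataclass


@dataclass
class Cell:
    """Hypothetical stand-in for a ref that holds a single value."""
    value: int


def cas_model(ref: Cell, cmp: int, val: int) -> int:
    """Sequential (non-atomic) model of compare-and-swap; returns the prior value."""
    old = ref.value
    if old == cmp:
        ref.value = val
    return old


def fetch_max_via_cas(ref: Cell, x: int) -> int:
    """Hypothetical helper: raise ref.value to at least x with a CAS retry loop.

    Returns the value that was stored before the successful swap.
    """
    expected = ref.value
    while True:
        observed = cas_model(ref, expected, max(expected, x))
        if observed == expected:
            return observed   # swap succeeded
        expected = observed   # another writer got there first; retry


cell = Cell(3)
print(fetch_max_via_cas(cell, 10), cell.value)  # 3 10
print(fetch_max_via_cas(cell, 4), cell.value)   # 10 10
```

The design choice worth noting is that the loop never trusts its earlier read: each failed swap hands back the fresh value, which becomes the next expected value.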