jax.experimental.pallas.atomic_cas
- jax.experimental.pallas.atomic_cas(ref, cmp, val)
Performs an atomic compare-and-swap on the value in ref: if it equals cmp, it is replaced with val; otherwise ref is left unchanged.
- Parameters:
ref – The ref to operate on.
cmp – The expected value to compare the current contents of ref against.
val – The value written to ref if the comparison succeeds.
- Returns:
The value held in ref immediately before the atomic operation, whether or not the swap took place.
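A minimal usage sketch follows. It assumes a scalar int32 buffer, Pallas interpret mode (`interpret=True`, so it also runs on CPU), and JAX's default 32-bit mode, so that the Python int arguments for cmp and val match the buffer's dtype. The helper name `cas_once` is hypothetical. Aliasing output 0 to input 0 gives the kernel an output ref that starts out holding the initial value.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def cas_once(init_value: int, cmp: int, new_value: int):
    """Hypothetical helper: runs one pl.atomic_cas on a scalar int32 buffer."""

    @functools.partial(
        pl.pallas_call,
        out_shape=(
            jax.ShapeDtypeStruct((), jnp.int32),  # buffer contents after the CAS
            jax.ShapeDtypeStruct((), jnp.int32),  # value returned by atomic_cas
        ),
        # Output 0 aliases input 0, so buf_ref starts out holding init_value.
        input_output_aliases={0: 0},
        # Assumption: interpret mode, so the sketch does not need a GPU.
        interpret=True,
    )
    def kernel(_, buf_ref, old_ref):
        # Writes new_value into buf_ref only if it currently equals cmp;
        # in both cases atomic_cas returns the value held before the call.
        old_ref[()] = pl.atomic_cas(buf_ref, cmp, new_value)

    return kernel(jnp.asarray(init_value, dtype=jnp.int32))


buf, old = cas_once(init_value=0, cmp=0, new_value=1)  # comparison succeeds
print(int(buf), int(old))
buf, old = cas_once(init_value=0, cmp=2, new_value=1)  # comparison fails
print(int(buf), int(old))
```

Because the returned value is the pre-operation contents in both cases, a caller can tell whether its swap won by checking whether the returned value equals cmp. This is the usual building block for a spin lock.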