jax.experimental.pallas.swap

jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')

Swaps the value at the given index and returns the old value.

val is the new value to store at idx. See load() for the meaning of the remaining arguments.

Returns:

The value stored in the ref prior to the swap.

Return type:

jax.Array
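The following is a minimal sketch of swap() inside a kernel. It assumes a JAX version that exports pl.swap and pl.ds, and it passes interpret=True so the kernel runs on CPU through the Pallas interpreter. The names swap_kernel, x, and new are hypothetical and chosen only for illustration. The kernel stores new values into an output ref and writes the values returned by swap() (the pre-swap contents) to a second output.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

N = 8

def swap_kernel(x_ref, new_ref, out_ref, old_ref):
    # Seed the output buffer with the original values.
    out_ref[...] = x_ref[...]
    # Store new_ref's values into out_ref[0:N]; swap returns the values
    # that were at those positions before the store.
    old = pl.swap(out_ref, (pl.ds(0, N),), new_ref[...])
    # Write the returned (pre-swap) values to a second output.
    old_ref[...] = old

x = jnp.arange(N, dtype=jnp.float32)
new = jnp.full((N,), -1.0, dtype=jnp.float32)

out, old = pl.pallas_call(
    swap_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((N,), jnp.float32),
        jax.ShapeDtypeStruct((N,), jnp.float32),
    ),
    interpret=True,  # run on CPU via the Pallas interpreter
)(x, new)

print(out)  # holds the values from `new`
print(old)  # holds the original values of `x`
```

Because swap() performs the store and the read in one operation, the kernel does not need a separate load() before the store to recover the overwritten values.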