jax.experimental.pallas.swap

jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')

Stores val into the ref at the given index and returns the value that was there before the swap. See load() for the meaning of the shared arguments (x_ref_or_view, idx, mask, eviction_policy).

Returns: The value stored in the ref prior to the swap.

Return type: jax.Array
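The following is a minimal sketch of swap used inside a Pallas kernel. It assumes a JAX version in which pl.swap is exported with the signature documented above, and it runs with interpret=True so no TPU or GPU is needed. The kernel name swap_kernel and the wrapper swap_whole_array are hypothetical names chosen for illustration. The kernel first copies x into an output buffer, then swaps y into that buffer over the whole array, so the value returned by swap is the old contents (x) and the buffer ends up holding y.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def swap_kernel(x_ref, y_ref, old_ref, final_ref):
    # Hypothetical kernel: seed the output buffer with x so the swap has
    # something to displace.
    final_ref[...] = x_ref[...]
    # Store y into every element of final_ref; pl.swap returns the values
    # that were in final_ref before the store.
    old = pl.swap(final_ref, (slice(None),), y_ref[...])
    # Expose the displaced values as a second output.
    old_ref[...] = old


def swap_whole_array(x, y):
    # Hypothetical wrapper: two outputs with the same shape/dtype as x.
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        swap_kernel,
        out_shape=(shape, shape),
        interpret=True,  # run on CPU without a TPU/GPU backend
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.full((8,), 10.0, dtype=jnp.float32)
old, final = swap_whole_array(x, y)
# old holds the previous contents of the buffer (x); final holds y.
```

Because the kernel reads the returned value and the ref afterwards, the example shows both halves of the contract: the ref is updated with val, and the call's result is the prior contents.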