jax.experimental.pallas.load

jax.experimental.pallas.load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy=None, volatile=False)

Returns an array loaded from x_ref_or_view at the given index.

If neither mask nor other is specified, this function has the same semantics as x_ref_or_view[idx] in JAX.
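A minimal sketch of that equivalence, assuming a JAX version where pl.load is available and using interpret=True so it runs without a GPU or TPU. The kernel name, array size, and slice are illustrative only. The kernel writes the same slice twice, once through pl.load and once through plain ref indexing.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def load_vs_index_kernel(x_ref, via_load_ref, via_index_ref):
    # Hypothetical example kernel: select elements 2, 3, 4, 5.
    idx = (pl.ds(2, 4),)
    # With no mask/other, pl.load(ref, idx) matches ref[idx].
    via_load_ref[...] = pl.load(x_ref, idx)
    via_index_ref[...] = x_ref[idx]


x = jnp.arange(8, dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct((4,), jnp.float32)
via_load, via_index = pl.pallas_call(
    load_vs_index_kernel,
    out_shape=(out_shape, out_shape),
    interpret=True,  # use the Pallas interpreter so this also runs on CPU
)(x)
print(via_load)  # [2. 3. 4. 5.]
```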

Parameters:
  • x_ref_or_view – The ref, or a view of a ref, to load from.

  • idx – The indexer selecting which elements of x_ref_or_view to load, as in x_ref_or_view[idx].

  • mask – An optional boolean mask specifying which indices to load. At positions where mask is False and other is not given, the values in the resulting array are unspecified; no assumptions can be made about them.

  • other – An optional value to use for indices where mask is False.

  • cache_modifier – TO BE DOCUMENTED.

  • eviction_policy – TO BE DOCUMENTED.

  • volatile – TO BE DOCUMENTED.
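A minimal sketch of a masked load with a fill value, again assuming pl.load is available and using interpret=True to run on CPU. The constants N and VALID and the kernel name are hypothetical. Because positions where mask is False are unspecified without other, the sketch always supplies other, here as an array with the same shape and dtype as the loaded block.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

N = 8       # hypothetical block length
VALID = 5   # hypothetical cutoff: only the first 5 positions hold real data


def masked_load_kernel(x_ref, o_ref):
    idx = (pl.ds(0, N),)
    # True at positions 0..4, False at positions 5..7.
    mask = jnp.arange(N) < VALID
    # Values used wherever mask is False.
    fill = jnp.full((N,), -1.0, dtype=jnp.float32)
    o_ref[...] = pl.load(x_ref, idx, mask=mask, other=fill)


x = jnp.arange(N, dtype=jnp.float32)
out = pl.pallas_call(
    masked_load_kernel,
    out_shape=jax.ShapeDtypeStruct((N,), jnp.float32),
    interpret=True,  # use the Pallas interpreter so this also runs on CPU
)(x)
# out.tolist() == [0.0, 1.0, 2.0, 3.0, 4.0, -1.0, -1.0, -1.0]
```

A common use of mask is guarding the tail of a block that extends past the end of the input; the sketch keeps the block fully in bounds so that only the mask/other semantics are on display.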

Return type:

jax.Array