jax.experimental.pallas.load
- jax.experimental.pallas.load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy=None, volatile=False)
Returns an array loaded from the given index.
If neither mask nor other is specified, this function has the same semantics as x_ref_or_view[idx] in JAX.
- Parameters:
x_ref_or_view – The ref to load from.
idx – The indexer to use.
mask – An optional boolean mask specifying which indices to load. If mask is False and other is not given, no assumptions can be made about the value in the resulting array.
other – An optional value to use for indices where mask is False.
cache_modifier – TO BE DOCUMENTED.
eviction_policy – TO BE DOCUMENTED.
volatile – TO BE DOCUMENTED.
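The following is a minimal sketch, not taken from the JAX documentation, contrasting an unmasked load with a masked one. It assumes the kernel is run with interpret=True so it works on CPU without a GPU or TPU. The names load_demo_kernel, load_demo, N and LIMIT are hypothetical and exist only for illustration. The first load passes neither mask nor other, so it should match x_ref[idx]. The second load reads only the positions where mask is True and fills the remaining positions from other.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

N = 8      # hypothetical array length for this sketch
LIMIT = 5  # hypothetical bound: positions >= LIMIT are treated as out of range


def load_demo_kernel(x_ref, plain_ref, masked_ref):
    # Index the whole 1-D ref with a slice of length N starting at 0.
    idx = (pl.ds(0, N),)

    # No mask and no other: same semantics as reading x_ref[idx].
    plain_ref[...] = pl.load(x_ref, idx)

    # Positions where mask is True are read from x_ref; positions where
    # it is False take the corresponding value from `other` instead.
    mask = jnp.arange(N) < LIMIT
    fill = jnp.full((N,), -1.0, dtype=x_ref.dtype)
    masked_ref[...] = pl.load(x_ref, idx, mask=mask, other=fill)


def load_demo(x):
    out_shape = (
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    # interpret=True runs the kernel with the Pallas interpreter (assumed
    # here so the sketch runs on CPU).
    return pl.pallas_call(load_demo_kernel, out_shape=out_shape, interpret=True)(x)


x = jnp.arange(N, dtype=jnp.float32)
plain, masked = load_demo(x)
print(plain)
print(masked)
```

Supplying other makes the masked-off positions well defined. Without it, as noted above, their values should not be relied on. Backend support for masked loads on real hardware may vary, so treat the interpreter run as a semantic illustration only.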
- Return type: