jax.numpy.ravel_multi_index
- jax.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C')
Convert multi-dimensional indices into flat indices.
JAX implementation of numpy.ravel_multi_index().
- Parameters:
multi_index (Sequence[ArrayLike]) – sequence of integer arrays containing indices in each dimension.
dims (Sequence[int]) – sequence of integer sizes; must have len(dims) == len(multi_index).
mode (str) – how to handle out-of-bounds indices. Options are:
"raise" (default): raise a ValueError. This mode is incompatible with jit() or other JAX transformations.
"clip": clip out-of-bounds indices to the valid range.
"wrap": wrap out-of-bounds indices to the valid range.
order (str) – "C" (default) or "F"; specifies whether to assume C-style row-major order or Fortran-style column-major order.
- Returns:
array of flattened indices
- Return type:
Array
See also
jax.numpy.unravel_index(): inverse of this function.
Examples
Define a 2-dimensional array and a sequence of indices of even values:
>>> x = jnp.array([[2., 3., 4.],
...                [5., 6., 7.]])
>>> indices = jnp.where(x % 2 == 0)
>>> indices
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
>>> x[indices]
Array([2., 4., 6.], dtype=float32)
Compute the flattened indices:
>>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
>>> indices_flat
Array([0, 2, 4], dtype=int32)
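As a rough sanity check, the flat indices can also be reproduced by hand from the definitions of row-major and column-major order. The sketch below assumes a 2-dimensional shape of (2, 3) and the same index positions as in the example; the names rows, cols, manual_c, and manual_f are illustrative and not part of the API.

```python
import jax.numpy as jnp

# Index positions (0, 0), (0, 2), (1, 1) in an array of shape (2, 3).
rows = jnp.array([0, 0, 1])
cols = jnp.array([0, 2, 1])
shape = (2, 3)

# C (row-major) order: the last dimension varies fastest,
# so flat = row * n_cols + col.
manual_c = rows * shape[1] + cols
flat_c = jnp.ravel_multi_index((rows, cols), shape, order="C")
print(manual_c, flat_c)  # both [0 2 4]

# Fortran (column-major) order: the first dimension varies fastest,
# so flat = row + col * n_rows.
manual_f = rows + cols * shape[0]
flat_f = jnp.ravel_multi_index((rows, cols), shape, order="F")
print(manual_f, flat_f)  # both [0 4 3]
```

Note that indices computed with order="F" address an array flattened in column-major order, so they are not interchangeable with the order="C" result.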
These flattened indices can be used to extract the same values from the flattened x array:
>>> x_flat = x.ravel()
>>> x_flat
Array([2., 3., 4., 5., 6., 7.], dtype=float32)
>>> x_flat[indices_flat]
Array([2., 4., 6.], dtype=float32)
The original indices can be recovered with unravel_index():
>>> jnp.unravel_index(indices_flat, x.shape)
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
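The mode options described under Parameters can be illustrated with a few deliberately out-of-bounds indices. This is a minimal sketch: the index values and the helper name flat_clip are made up for illustration, and dims is assumed to be a static Python tuple so that the call can be traced under jit().

```python
import jax
import jax.numpy as jnp

dims = (2, 3)
# Valid rows are 0..1 and valid columns are 0..2, so the second and
# third points below each contain an out-of-bounds entry.
rows = jnp.array([1, 2, -1])
cols = jnp.array([2, 4, 0])

# "clip" clamps each index into [0, dim - 1] before flattening:
# rows -> [1, 1, 0], cols -> [2, 2, 0].
clipped = jnp.ravel_multi_index((rows, cols), dims, mode="clip")
print(clipped)  # [5 5 0]

# "wrap" takes each index modulo its dimension size:
# rows -> [1, 0, 1], cols -> [2, 1, 0].
wrapped = jnp.ravel_multi_index((rows, cols), dims, mode="wrap")
print(wrapped)  # [5 1 3]

# Unlike the default "raise" mode, "clip" and "wrap" do not need to
# inspect concrete index values, so they can be used inside jit().
@jax.jit
def flat_clip(r, c):  # hypothetical helper, for illustration only
    return jnp.ravel_multi_index((r, c), dims, mode="clip")

print(flat_clip(rows, cols))  # [5 5 0]
```

Because "clip" and "wrap" silently remap invalid indices, the default "raise" mode is usually the safer choice outside of transformed code.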