jax.numpy.unravel_index#
- jax.numpy.unravel_index(indices, shape)[source]#
Convert flat indices into multi-dimensional indices.
JAX implementation of
numpy.unravel_index()
. The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported, and out-of-bound indices are clipped to the nearest valid value.- Parameters:
indices (ArrayLike) – integer array of flat indices
shape (Shape) – shape of multidimensional array to index into
- Returns:
Tuple of unraveled indices
- Return type:
See also
jax.numpy.ravel_multi_index()
: Inverse of this function.Examples
Start with a 1D array values and indices:
>>> x = jnp.array([2., 3., 4., 5., 6., 7.]) >>> indices = jnp.array([1, 3, 5]) >>> print(x[indices]) [3. 5. 7.]
Now if
x
is reshaped,unravel_indices
can be used to convert the flat indices into a tuple of indices that access the same entries:>>> shape = (2, 3) >>> x_2D = x.reshape(shape) >>> indices_2D = jnp.unravel_index(indices, shape) >>> indices_2D (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) >>> print(x_2D[indices_2D]) [3. 5. 7.]
The inverse function,
ravel_multi_index
, can be used to obtain the original indices:>>> jnp.ravel_multi_index(indices_2D, shape) Array([1, 3, 5], dtype=int32)