jax.numpy.nanargmin
- jax.numpy.nanargmin(a, axis=None, out=None, keepdims=None)
Return the index of the minimum value of an array, ignoring NaNs.
JAX implementation of numpy.nanargmin().
- Parameters:
a (ArrayLike) – input array
axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
- Returns:
an array containing the index of the minimum value along the specified axis.
- Return type:
Array
Note
In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmin(), which raises an error.
See also
jax.numpy.argmin(): return the index of the minimum value.
jax.numpy.nanargmax(): compute argmax while ignoring NaN values.
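The Note above can be made concrete with a short sketch. The helper `nanargmin_sketch` below is hypothetical and is not taken from the JAX source: it replaces NaNs with +inf, calls the ordinary `jnp.argmin`, and then writes -1 wherever every entry along the axis was NaN. It assumes a floating-point input; integer arrays cannot hold NaN, so `jnp.argmin` alone would suffice for them.

```python
import jax.numpy as jnp


def nanargmin_sketch(a, axis=None, keepdims=False):
    """Hypothetical helper: argmin that skips NaNs and returns -1 for all-NaN slices.

    Assumes `a` has a floating-point dtype.
    """
    a = jnp.asarray(a)
    nan_mask = jnp.isnan(a)
    # NaNs become +inf, so they can only be chosen when nothing else is present.
    filled = jnp.where(nan_mask, jnp.inf, a)
    idx = jnp.argmin(filled, axis=axis, keepdims=keepdims)
    # Slices that are entirely NaN get the sentinel index -1.
    all_nan = jnp.all(nan_mask, axis=axis, keepdims=keepdims)
    return jnp.where(all_nan, -1, idx)


x = jnp.array([[2.0, jnp.nan, 1.0],
               [jnp.nan, jnp.nan, jnp.nan]])
result = nanargmin_sketch(x, axis=1)  # row 0 -> 2, row 1 (all NaN) -> -1
```

Masking with +inf, rather than dropping the NaN entries, keeps every intermediate array at a static shape, which JAX transformations such as `jax.jit` require.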
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([jnp.nan, 3, 5, 4, 2])
>>> jnp.nanargmin(x)
Array(4, dtype=int32)

>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmin(x, axis=1)
Array([0, 1], dtype=int32)

>>> jnp.nanargmin(x, axis=1, keepdims=True)
Array([[0],
       [1]], dtype=int32)
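Two further usage sketches follow, using illustrative arrays that are not part of the original examples. The first shows that with `axis=None` the result is an index into the flattened input, which `jnp.unravel_index` can map back to per-axis coordinates. The second guards against the -1 returned for all-NaN rows (see the Note above): because -1 is a valid negative index, using it directly would select the last element, so the code masks it before gathering values with `jnp.take_along_axis`. Substituting NaN for a missing minimum is a choice made in this sketch, not a JAX convention.

```python
import jax.numpy as jnp

# 1) axis=None: the result indexes the flattened array.
x = jnp.array([[4.0, jnp.nan],
               [1.0, 3.0]])
flat_idx = jnp.nanargmin(x)                      # flattened: [4., nan, 1., 3.] -> index 2
row, col = jnp.unravel_index(flat_idx, x.shape)  # back to (row, column) coordinates

# 2) All-NaN rows return -1, which would silently pick the last
#    element if used as an index, so mask it before gathering.
y = jnp.array([[3.0, jnp.nan, 1.0],
               [jnp.nan, jnp.nan, jnp.nan]])
idx = jnp.nanargmin(y, axis=1, keepdims=True)    # shape (2, 1)
valid = idx >= 0
safe_idx = jnp.where(valid, idx, 0)              # any in-range placeholder index
mins = jnp.take_along_axis(y, safe_idx, axis=1)
mins = jnp.where(valid, mins, jnp.nan)           # NaN marks "no minimum" here
```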