jax.numpy.put

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)

Put elements into an array at given indices.

JAX implementation of numpy.put().

The semantics of numpy.put() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference. Leaving inplace at its default value of True raises an error.
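A short illustration of this difference, assuming a recent JAX release: the input array is left untouched, and calling jnp.put without inplace=False is rejected. The specific exception type caught here (ValueError) reflects current JAX behavior and is an assumption that may not hold across all versions.

```python
import jax.numpy as jnp

x = jnp.zeros(3, dtype=int)

# jnp.put returns a new array; x itself is not modified.
y = jnp.put(x, jnp.array([1]), jnp.array([7]), inplace=False)
print(x)  # [0 0 0]
print(y)  # [0 7 0]

# Leaving inplace at its default (True) is rejected instead of
# silently returning a copy.
raised = False
try:
    jnp.put(x, jnp.array([1]), jnp.array([7]))
except ValueError as err:
    raised = True
    print("ValueError:", err)
```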

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – array into which values will be placed.

  • ind (Array | ndarray | bool | number | bool | int | float | complex) – array of indices over the flattened array at which to put values.

  • v (Array | ndarray | bool | number | bool | int | float | complex) – array of values to put into the array.

  • mode (str | None) –

    string specifying how to handle out-of-bound indices. Supported values:

    • "clip" (default): clip out-of-bound indices to the final index.

    • "wrap": wrap out-of-bound indices to the beginning of the array.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.

Returns:

A copy of a with specified entries updated.

Return type:

Array

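See also

  • jax.numpy.ndarray.at: array updates using NumPy-style indexing.

  • numpy.put: the NumPy function that this function implements.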

Examples

>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

This is equivalent to the following jax.numpy.ndarray.at indexing syntax:

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)

There are two modes for handling out-of-bound indices. By default they are clipped:

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

Alternatively, they can be wrapped to the beginning of the array:

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)
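As a rough mental model, both modes amount to rewriting the indices before an ordinary flat .at[...].set(...) update: "clip" clamps them into the valid range and "wrap" reduces them modulo the array size. The helper put_sketch below is hypothetical (it is not part of JAX and not the library's actual implementation), and it assumes ind and v have the same length.

```python
import jax.numpy as jnp

def put_sketch(a, ind, v, mode="clip"):
    """Hypothetical re-implementation of jnp.put's "clip" and "wrap" modes."""
    a = jnp.asarray(a)
    ind = jnp.asarray(ind)
    size = a.size
    if mode == "clip":
        # Clamp indices into [0, size - 1]; indices past the end
        # therefore map to the final index.
        ind = jnp.clip(ind, 0, size - 1)
    elif mode == "wrap":
        # Wrap indices around modulo the flattened size.
        ind = ind % size
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    # Update the flattened array, then restore the original shape.
    flat = a.ravel().at[ind].set(jnp.asarray(v, dtype=a.dtype))
    return flat.reshape(a.shape)

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 6])
values = jnp.array([10, 20, 30])
print(put_sketch(x, indices, values, mode="clip"))  # [10  0 20  0 30]
print(put_sketch(x, indices, values, mode="wrap"))  # [10 30 20  0  0]
```

With index 6 on a length-5 array, "clip" sends it to 4 while "wrap" sends it to 6 % 5 = 1, which matches the two outputs shown above.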

For N-dimensional inputs, the indices refer to the flattened array:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
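One way to see the flattened-index rule is to convert each flat index into per-axis indices with jnp.unravel_index and apply the update with .at. This sketch assumes the default row-major (C) ordering and checks that both routes agree for the example above.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
indices = jnp.array([0, 7, 14])
values = jnp.array([10, 20, 30])

# Convert flat indices into (row, column) pairs for a (3, 5) array:
# 0 -> (0, 0), 7 -> (1, 2), 14 -> (2, 4).
rows, cols = jnp.unravel_index(indices, x.shape)
via_at = x.at[rows, cols].set(values)

via_put = jnp.put(x, indices, values, inplace=False)
print(bool((via_at == via_put).all()))  # True
```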