jax.numpy.put
- jax.numpy.put(a, ind, v, mode=None, *, inplace=True)
Put elements into an array at given indices.
JAX implementation of numpy.put().
The semantics of numpy.put() are to modify arrays in-place, which is not possible for JAX's immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
- Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – array into which values will be placed.
ind (Array | ndarray | bool | number | bool | int | float | complex) – array of indices over the flattened array at which to put values.
v (Array | ndarray | bool | number | bool | int | float | complex) – array of values to put into the array.
mode (str | None) – string specifying how to handle out-of-bound indices. Supported values:
  - "clip" (default): clip out-of-bound indices to the final index.
  - "wrap": wrap out-of-bound indices to the beginning of the array.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Returns:
A copy of a with specified entries updated.
- Return type:
Array
See also
jax.numpy.place(): place elements into an array via boolean mask.
jax.numpy.ndarray.at(): array updates using NumPy-style indexing.
jax.numpy.take(): extract values from an array at given indices.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)
This is equivalent to the following jax.numpy.ndarray.at indexing syntax:
>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)
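The API difference that the inplace parameter is meant to flag can be checked directly. The sketch below is illustrative only: it assumes a recent JAX release in which leaving inplace at its default raises a ValueError (the exception type is an assumption, not stated above), and it uses NumPy alongside JAX purely for contrast. The names updated, raised, x_np, and result are hypothetical.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 4])
values = jnp.array([10, 20, 30])

# jnp.put returns a modified copy; the input array x is left untouched.
updated = jnp.put(x, indices, values, inplace=False)
print(x)        # still all zeros
print(updated)  # holds the new values

# Leaving inplace at its default (True) is rejected, since JAX arrays
# cannot be modified in place. Assumption: the error is a ValueError.
try:
    jnp.put(x, indices, values)
    raised = False
except ValueError:
    raised = True
print(raised)

# For contrast, numpy.put mutates its argument and returns None.
x_np = np.zeros(5, dtype=int)
result = np.put(x_np, [0, 2, 4], [10, 20, 30])
print(result, x_np)
```

Because the result is a new array, it has to be bound to a name (or returned) to be used; calling jnp.put and discarding the return value has no effect.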
There are two modes for handling out-of-bound indices. By default they are clipped:
>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)
Alternatively, they can be wrapped to the beginning of the array:
>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)
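One way to read the two modes is as a preprocessing step on the indices before an ordinary indexed update. The following is a minimal sketch of that reading for non-negative indices, not a description of how JAX implements it internally; the names clipped, wrapped, manual_clip, and manual_wrap are hypothetical, and the mode is always passed explicitly.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 6])
values = jnp.array([10, 20, 30])

# "clip": out-of-bound indices are moved to the last valid index (6 -> 4).
clipped = jnp.clip(indices, 0, x.size - 1)
# "wrap": out-of-bound indices are taken modulo the array size (6 -> 1).
wrapped = indices % x.size

# Apply the adjusted indices with the .at[...].set(...) update syntax.
manual_clip = x.at[clipped].set(values)
manual_wrap = x.at[wrapped].set(values)

# Compare against jnp.put with the corresponding explicit mode.
put_clip = jnp.put(x, indices, values, mode="clip", inplace=False)
put_wrap = jnp.put(x, indices, values, mode="wrap", inplace=False)
print(bool(jnp.array_equal(manual_clip, put_clip)))
print(bool(jnp.array_equal(manual_wrap, put_wrap)))
```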
For N-dimensional inputs, the indices refer to the flattened array:
>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
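To make the flattened-index convention concrete, the flat positions can be converted to row and column coordinates with jnp.unravel_index and then applied through the .at[...].set(...) syntax. This is a sketch of the equivalence for in-bounds indices; the names rows, cols, via_at, and via_put are hypothetical.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
indices = jnp.array([0, 7, 14])
values = jnp.array([10, 20, 30])

# Convert flat (row-major) positions into per-axis coordinates.
# For shape (3, 5): 0 -> (0, 0), 7 -> (1, 2), 14 -> (2, 4).
rows, cols = jnp.unravel_index(indices, x.shape)

# Update through explicit 2-D coordinates ...
via_at = x.at[rows, cols].set(values)
# ... and through jnp.put with flat indices.
via_put = jnp.put(x, indices, values, inplace=False)

print(bool(jnp.array_equal(via_at, via_put)))
```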