jax.numpy.append
jax.numpy.append(arr, values, axis=None)
Return a new array with values appended to the end of the original array.
JAX implementation of numpy.append().
- Parameters:
  - arr (ArrayLike) – original array.
  - values (ArrayLike) – values to be appended to the array. When axis is specified, values must have the same number of dimensions as arr, and all dimensions must match except along axis. When axis is None, values may have any shape.
  - axis (int | None) – axis along which to append values. If None (default), both arr and values are flattened before appending.
- Returns:
  A new array with values appended to arr.
- Return type:
  Array
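The parameter rules above can be restated in terms of two other jax.numpy functions. The following is a minimal sketch, not the library's actual source: it assumes only that jnp.ravel and jnp.concatenate behave as documented, and the name append_sketch is hypothetical. With axis=None both inputs are flattened and joined; otherwise they are joined along the given axis, which is why the ranks and all non-axis dimensions must agree.

```python
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def append_sketch(arr: ArrayLike, values: ArrayLike, axis: int | None = None) -> jax.Array:
    """Hypothetical restatement of the documented jnp.append semantics."""
    if axis is None:
        # Default case: flatten both inputs to 1-D, then join them end to end.
        return jnp.concatenate([jnp.ravel(arr), jnp.ravel(values)])
    # Explicit axis: join along `axis`. jnp.concatenate enforces that both
    # inputs have the same ndim and matching sizes on every other axis.
    return jnp.concatenate([arr, values], axis=axis)


a = jnp.array([[1, 2], [3, 4]])
print(append_sketch(a, jnp.array([5, 6, 7])))          # flattened, shape (7,)
print(append_sketch(a, jnp.array([[5, 6]]), axis=0))   # stacked rows, shape (3, 2)
```

Because the sketch defers all shape checking to jnp.concatenate, inputs with mismatched ranks or non-axis sizes raise an error there rather than in append_sketch itself.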
See also
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.append(a, b)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Appending along a specific axis:
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6]])
>>> jnp.append(a, b, axis=0)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
Appending along a trailing axis:
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[7], [8]])
>>> jnp.append(a, b, axis=1)
Array([[1, 2, 3, 7],
       [4, 5, 6, 8]], dtype=int32)
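Two further properties follow from the description above: with the default axis=None the inputs are flattened first, so their shapes need not be compatible, and append returns a new array rather than modifying arr. The snippet below is an illustrative sketch that assumes a working jax installation; the helper name append_cols is hypothetical. It also shows append under jax.jit, which works because the output shape depends only on the input shapes.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6]])
b = jnp.array([[7], [8]])

# Default axis=None: both inputs are flattened, so a (2, 3) array and a
# (2, 1) array can be appended even though their shapes differ.
flat = jnp.append(a, b)

# Appending does not modify `a`; a new array is returned instead.
print(a.shape)  # still (2, 3)

# Hypothetical helper: append columns inside a jit-compiled function.
append_cols = jax.jit(lambda x, y: jnp.append(x, y, axis=1))
wide = append_cols(a, b)

print(flat)
print(wide)
```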