jax.numpy.cumsum
- jax.numpy.cumsum(a, axis=None, dtype=None, out=None)
Cumulative sum of elements along an axis.
JAX implementation of numpy.cumsum().
- Parameters:
a (ArrayLike) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), the array is flattened and accumulated along the single flattened axis.
dtype (DTypeLike | None) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Returns:
An array containing the accumulated sum along the given axis.
- Return type:
Array
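The effect of the axis and dtype parameters can be checked directly. This is a minimal sketch that assumes a standard JAX installation in its default 32-bit mode; the input is created with an explicit int32 dtype so that the dtype checks do not depend on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

# Explicit int32 input so the dtype checks below do not depend on x64 mode.
x = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.int32)

# axis=None (the default): x is flattened first, so this is the same as
# accumulating over x.ravel().
flat = jnp.cumsum(x)
flat_ref = jnp.cumsum(x.ravel(), axis=0)

# axis=0 accumulates down each column: [[1, 2, 3], [5, 7, 9]].
down = jnp.cumsum(x, axis=0)

# Negative axes count from the end, so axis=-1 matches axis=1 for a 2-D input.
across = jnp.cumsum(x, axis=1)
across_neg = jnp.cumsum(x, axis=-1)

# With dtype unset the result keeps the input dtype (int32 here); with
# dtype=jnp.float32 the result is returned as float32.
as_float = jnp.cumsum(x, axis=1, dtype=jnp.float32)

print(flat.shape, down.shape)  # (6,) (2, 3)
print(across.dtype, as_float.dtype)  # int32 float32
```

Note that with axis=None the output is always 1-D, while an explicit axis preserves the input shape.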
See also
jax.numpy.cumulative_sum(): cumulative sum via the array API standard.
jax.numpy.add.accumulate(): cumulative sum via ufunc methods.
jax.numpy.nancumsum(): cumulative sum ignoring NaN values.
jax.numpy.sum(): sum along axis.
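Two of the related functions listed above connect to jnp.cumsum in simple ways, sketched here assuming a default JAX installation. The last entry of a cumulative sum equals jnp.sum over the same axis. jnp.nancumsum gives the same values as jnp.cumsum applied after replacing NaN with zero. That replace-with-zero formulation only illustrates the behavior; it is not a claim about how jnp.nancumsum is implemented.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.int32)

# The final running total along an axis is the plain sum along that axis.
last = jnp.cumsum(x, axis=1)[:, -1]
totals = jnp.sum(x, axis=1)

v = jnp.array([1.0, jnp.nan, 2.0], dtype=jnp.float32)

# cumsum propagates NaN: every running total from the NaN onward is NaN.
plain = jnp.cumsum(v)

# nancumsum treats NaN as zero, so the accumulation continues past it.
skipping = jnp.nancumsum(v)

# Same values via an explicit NaN-to-zero replacement before accumulating.
manual = jnp.cumsum(jnp.where(jnp.isnan(v), 0.0, v))
```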
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
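Along the chosen axis, each output entry is the previous running total plus the current input entry: out[i] = out[i-1] + a[i]. The sketch below writes that recurrence out with jax.lax.scan and checks it against jnp.cumsum on the array from the examples. scan_cumsum is a hypothetical helper for illustration only, not how JAX implements jnp.cumsum. The sketch also shows that jnp.diff with a prepended zero undoes the accumulation.

```python
import jax
import jax.numpy as jnp

def scan_cumsum(a):
    """Hypothetical helper: cumulative sum over the leading axis via lax.scan."""
    def step(running_total, row):
        new_total = running_total + row
        # Carry the running total forward and also emit it as this step's output.
        return new_total, new_total

    # The carry has the shape of one slice along axis 0, filled with zeros.
    init = jnp.zeros(a.shape[1:], dtype=a.dtype)
    _, out = jax.lax.scan(step, init, a)
    return out

x = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.int32)

# lax.scan iterates over axis 0, so axis=1 is handled by transposing in and out.
via_scan = scan_cumsum(x.T).T
via_cumsum = jnp.cumsum(x, axis=1)

# Differencing inverts the accumulation: prepending a zero column lets
# jnp.diff return every original entry, including the first one.
recovered = jnp.diff(via_cumsum, axis=1, prepend=0)
```

The scan form favors readability: it processes one slice per step, mirroring the recurrence directly.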