jax.numpy.cumulative_prod
- jax.numpy.cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False)
Cumulative product along the axis of an array.
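For a one-dimensional input x of length n, the semantics implied by the parameter descriptions can be written as follows. This is a sketch of the behavior, not a formula taken from the JAX source. For N-dimensional input, the same rule applies independently to every 1-D slice along axis.

$$
y_i = \prod_{j=0}^{i} x_j, \qquad i = 0, \dots, n-1
$$

With include_initial=True, the multiplicative identity 1 is prepended, so the output has length n+1 along the accumulation axis:

$$
y_0 = 1, \qquad y_i = \prod_{j=0}^{i-1} x_j, \qquad i = 1, \dots, n
$$

For example, x = [1, 2, 3] gives [1, 2, 6], or [1, 1, 2, 6] with include_initial=True.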
JAX implementation of numpy.cumulative_prod().
- Parameters:
x (ArrayLike) – N-dimensional array
axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero. For inputs with more than one dimension, it must be given explicitly.
dtype (DTypeLike | None) – optional dtype of the output.
include_initial (bool) – if True, then include the initial value in the cumulative product. Default is False.
- Returns:
An array containing the accumulated values.
- Return type:
Array
See also
jax.numpy.cumprod(): alternative API for cumulative product.
jax.numpy.nancumprod(): cumulative product while ignoring NaN values.
jax.numpy.multiply.accumulate(): cumulative product via the ufunc API.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_prod(x, axis=1)
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
>>> jnp.cumulative_prod(x, axis=1, include_initial=True)
Array([[  1,   1,   2,   6],
       [  1,   4,  20, 120]], dtype=int32)