jax.numpy.trace
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)
Calculate the sum of the diagonal of the input array along the given axes.
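For a 2-D input with the default arguments, this is the sum of the elements a[i, i]. A minimal check, using an arbitrary 3x3 matrix chosen only for illustration, compares jnp.trace against an explicit loop and against numpy.trace:

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary 3x3 example matrix: [[0, 1, 2], [3, 4, 5], [6, 7, 8]].
a = jnp.arange(9).reshape(3, 3)

# Sum a[i, i] element by element for comparison.
manual = sum(int(a[i, i]) for i in range(min(a.shape)))

print(int(jnp.trace(a)))         # 0 + 4 + 8 = 12
print(manual)                    # 12
print(np.trace(np.asarray(a)))   # NumPy agrees: 12
```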
JAX implementation of numpy.trace().
- Parameters:
a (ArrayLike) – input array. Must have a.ndim >= 2.
offset (int | ArrayLike) – optional, default=0. Diagonal offset from the main diagonal. Positive values select diagonals above the main diagonal; negative values select diagonals below it.
axis1 (int) – optional, default=0. The first axis along which to take the sum of the diagonal. Must be a static integer value.
axis2 (int) – optional, default=1. The second axis along which to take the sum of the diagonal. Must be a static integer value.
dtype (DTypeLike | None) – optional. The dtype of the output array. Should be provided as a static argument under JIT compilation.
out (None) – Not used by JAX.
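A short sketch of how the parameters behave, using an arbitrary 3x3 matrix. The wrapper name batched_trace is hypothetical; it only illustrates that axis1, axis2 and dtype have to be marked static when jnp.trace is called inside jax.jit:

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])

# offset > 0: diagonal above the main one, a[0, 1] + a[1, 2] = 2 + 6.
above = jnp.trace(a, offset=1)
# offset < 0: diagonal below the main one, a[1, 0] + a[2, 1] = 4 + 8.
below = jnp.trace(a, offset=-1)
# dtype sets the dtype of the result: 1 + 5 + 9 = 15.0 as float32.
as_float = jnp.trace(a, dtype=jnp.float32)

# Hypothetical wrapper: axis1, axis2 and dtype must be static under jit,
# because they determine the shape and dtype of the output.
@partial(jax.jit, static_argnames=("axis1", "axis2", "dtype"))
def batched_trace(x, axis1, axis2, dtype=None):
    return jnp.trace(x, axis1=axis1, axis2=axis2, dtype=dtype)

stack = jnp.stack([a, 2 * a])                    # shape (2, 3, 3)
print(int(above), int(below), float(as_float))   # 8 12 15.0
print(batched_trace(stack, axis1=1, axis2=2))    # [15 30]
```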
- Returns:
An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).
- Return type:
Array
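The result keeps every axis except axis1 and axis2, so its ndim is a.ndim - 2. A small sketch with an array of ones, whose shape is chosen arbitrarily for illustration:

```python
import jax.numpy as jnp

# Illustrative input: a (2, 3) batch of 4x4 matrices of ones.
a = jnp.ones((2, 3, 4, 4))

# Default axes (0, 1): the diagonal runs over the first two axes
# (length min(2, 3) = 2), and the trailing axes (4, 4) remain.
t01 = jnp.trace(a)
print(t01.shape)        # (4, 4); every entry is 2.0

# Axes (2, 3): each 4x4 matrix is traced, and the batch axes (2, 3) remain.
t23 = jnp.trace(a, axis1=2, axis2=3)
print(t23.shape)        # (2, 3); every entry is 4.0
```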
See also
jax.numpy.diag(): Returns the specified diagonal or constructs a diagonal array.
jax.numpy.diagonal(): Returns the specified diagonal of an array.
jax.numpy.diagflat(): Returns a 2-D array with the flattened input array laid out on the diagonal.
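jnp.trace can be read as jnp.diagonal followed by a sum over the last axis, since jnp.diagonal places the extracted diagonal on the final axis. A sketch that checks this equivalence for several offsets on a 2x2x2 example array:

```python
import jax.numpy as jnp

x = jnp.arange(1, 9).reshape(2, 2, 2)

for offset in (-1, 0, 1):
    via_trace = jnp.trace(x, offset=offset, axis1=1, axis2=2)
    # jnp.diagonal puts the diagonal on the last axis, so summing that
    # axis should reproduce jnp.trace.
    via_diagonal = jnp.diagonal(x, offset=offset, axis1=1, axis2=2).sum(axis=-1)
    print(offset, via_trace, via_diagonal)
    assert (via_trace == via_diagonal).all()
```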
Examples
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
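To make the computation concrete, the trace can also be assembled from simpler jax.numpy operations: move axis1 and axis2 to the end, keep only the requested diagonal with a boolean jnp.eye mask, and sum the last two axes. masked_trace is a hypothetical helper written for illustration. It is not part of JAX and is not claimed to be how jnp.trace is implemented internally, but it reproduces the outputs of the examples above:

```python
import jax.numpy as jnp

def masked_trace(a, offset=0, axis1=0, axis2=1):
    """Hypothetical helper: trace via a diagonal mask (illustration only)."""
    a = jnp.asarray(a)
    n, m = a.shape[axis1], a.shape[axis2]
    # Bring the two diagonal axes to the end: shape (..., n, m).
    moved = jnp.moveaxis(a, (axis1, axis2), (-2, -1))
    # True exactly on the requested diagonal (k > 0 above, k < 0 below).
    mask = jnp.eye(n, m, k=offset, dtype=bool)
    # Zero everything else and sum over the two trailing axes.
    return jnp.where(mask, moved, 0).sum(axis=(-2, -1))

x = jnp.arange(1, 9).reshape(2, 2, 2)
print(masked_trace(x))                              # [ 8 10]
print(masked_trace(x, offset=1, axis1=1, axis2=2))  # [2 6]
```

The masking approach avoids indexing with computed positions, at the cost of materializing a full (n, m) boolean mask.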