jax.numpy.trace

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)

Calculate the sum of the diagonal of the input along the given axes.

JAX implementation of numpy.trace().
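As a quick consistency check, the result can be compared against numpy.trace(). It can also be compared against summing the output of jax.numpy.diagonal() over its last axis, which is where that function places the extracted diagonal. This is a minimal sketch that assumes NumPy is installed alongside JAX:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(1, 9).reshape(2, 2, 2)

# Reference result from NumPy, computed on a host copy of the same data.
np_result = np.trace(np.asarray(x))

# JAX result with the default axes (axis1=0, axis2=1).
jnp_result = jnp.trace(x)

# Equivalent formulation: jnp.diagonal moves the diagonal to the last axis,
# so summing over that axis yields the trace.
diag_sum = jnp.diagonal(x, offset=0, axis1=0, axis2=1).sum(axis=-1)
```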

Parameters:
  • a (ArrayLike) – input array. Must have a.ndim >= 2.

  • offset (int | ArrayLike) – optional, default=0. Offset of the diagonal from the main diagonal. Positive values select diagonals above the main diagonal, and negative values select diagonals below it.

  • axis1 (int) – optional, default=0. The first axis of the 2-D sub-arrays whose diagonals are summed. Must be a static integer value.

  • axis2 (int) – optional, default=1. The second axis of the 2-D sub-arrays whose diagonals are summed. Must be a static integer value.

  • dtype (DTypeLike | None) – optional. The dtype of the output array. Should be provided as a static argument under JIT compilation.

  • out (None) – Not used by JAX.
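Under jax.jit(), the static requirements above can be met with static_argnames. In the sketch below, the array a is traced, while axis1, axis2 and dtype are marked static. The wrapper name jitted_trace is hypothetical and used only for illustration:

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical wrapper: axis1, axis2 and dtype determine the shape and dtype
# of the compiled computation, so they are passed as static arguments.
@functools.partial(jax.jit, static_argnames=("axis1", "axis2", "dtype"))
def jitted_trace(a, axis1=0, axis2=1, dtype=None):
    return jnp.trace(a, axis1=axis1, axis2=axis2, dtype=dtype)


x = jnp.arange(1, 9).reshape(2, 2, 2)
inner = jitted_trace(x, axis1=1, axis2=2)                        # integer result
as_float = jitted_trace(x, axis1=1, axis2=2, dtype=jnp.float32)  # float32 result
```

Changing any static argument triggers a recompilation, which is the usual trade-off of static_argnames.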

Returns:

An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).

Return type:

Array
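This minimal sketch illustrates the output shape using a 4-D input. The axes named by axis1 and axis2 are removed, and the remaining axes keep their original order:

```python
import jax.numpy as jnp

# Traces are taken over the 2-D sub-arrays spanned by axes 1 and 3,
# so the output has shape (2, 4): axes 0 and 2 of the input.
a = jnp.ones((2, 3, 4, 5))
result = jnp.trace(a, axis1=1, axis2=3)

print(result.shape)  # (2, 4)
# Each entry sums min(3, 5) = 3 ones along the main diagonal.
print(result[0, 0])  # 3.0
```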

See also

Examples

>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
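The examples above use only offset=1. The following sketch on a square matrix adds two more cases. A negative offset selects a diagonal below the main one. An offset beyond the edge of the array selects an empty diagonal, whose sum is 0. The out-of-range case mirrors numpy.trace() and is shown as an illustration, not as a documented guarantee:

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)
# m = [[0, 1, 2],
#      [3, 4, 5],
#      [6, 7, 8]]

main = jnp.trace(m)              # 0 + 4 + 8
above = jnp.trace(m, offset=1)   # 1 + 5
below = jnp.trace(m, offset=-1)  # 3 + 7
empty = jnp.trace(m, offset=3)   # no elements on this diagonal, so 0
```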