jax.numpy.linalg.trace
- jax.numpy.linalg.trace(x, /, *, offset=0, dtype=None)
Compute the trace of a matrix.
JAX implementation of numpy.linalg.trace().
- Parameters:
x (ArrayLike) – array of shape (..., M, N) whose innermost two dimensions form MxN matrices for which to take the trace.
offset (int) – positive or negative offset from the main diagonal (default: 0). Positive values select a diagonal above the main diagonal; negative values select one below it.
dtype (DTypeLike | None) – data type of the returned array (default: None). If None, the output dtype matches the dtype of x, promoted to default precision in the case of integer types.
- Returns:
array of batched traces with shape x.shape[:-2].
- Return type:
Array
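One way to read the parameters and return shape above is that the trace at a given offset is the sum of the corresponding diagonal of the innermost two axes, with any leading axes kept as batch axes. The sketch below checks this reading against `jnp.diagonal`; the helper `trace_via_diagonal` is hypothetical, written only for comparison, and is not how JAX implements the function.

```python
import jax.numpy as jnp


def trace_via_diagonal(x, offset=0):
    # Hypothetical reference helper: extract the diagonal of the innermost
    # two axes (shifted by `offset`), then sum along that diagonal.
    # jnp.diagonal places the diagonal on the last axis, so leading batch
    # axes are preserved.
    diag = jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
    return jnp.sum(diag, axis=-1)


x = jnp.arange(24).reshape(2, 3, 4)  # a batch of two 3x4 matrices

for k in (-1, 0, 1):
    result = jnp.linalg.trace(x, offset=k)
    # The output has shape x.shape[:-2], here (2,).
    assert result.shape == x.shape[:-2]
    assert jnp.array_equal(result, trace_via_diagonal(x, offset=k))

# An explicit dtype controls the dtype of the returned array.
assert jnp.linalg.trace(x, dtype=jnp.float32).dtype == jnp.float32
```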
See also
jax.numpy.trace(): similar API in the jax.numpy namespace.
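The two functions differ mainly in which axes they reduce over. `jax.numpy.trace` follows `numpy.trace` and defaults to `axis1=0, axis2=1`, while `jax.numpy.linalg.trace` always uses the last two axes. The following sketch shows that passing the innermost axes explicitly to `jnp.trace` reproduces the batched result.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# jnp.linalg.trace reduces over the last two axes (the matrix axes).
batched = jnp.linalg.trace(x)

# jnp.trace needs the innermost axes named explicitly to match.
via_trace = jnp.trace(x, axis1=-2, axis2=-1)
assert jnp.array_equal(batched, via_trace)

# With its defaults, jnp.trace pairs axes 0 and 1 instead and keeps the
# last axis, so the result has shape (4,) rather than (2,).
assert jnp.trace(x).shape == (4,)
```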
Examples
Trace of a single matrix:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)
Batched traces:
>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)
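Because every axis before the last two is treated as a batch axis, calling the function on a stack of matrices should agree with mapping it over the stack one matrix at a time. The sketch below checks this with `jax.vmap` and also shows an input with two leading batch axes.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# Map the single-matrix trace over the leading axis with jax.vmap; the
# result matches calling jnp.linalg.trace on the whole batch directly.
per_matrix = jax.vmap(jnp.linalg.trace)(x)
assert jnp.array_equal(per_matrix, jnp.linalg.trace(x))

# Several leading axes are all batch axes: a (2, 2, 3, 3) input gives a
# (2, 2) array of traces. Each 3x3 matrix of ones has trace 3.
y = jnp.ones((2, 2, 3, 3))
assert jnp.linalg.trace(y).shape == (2, 2)
```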