jax.numpy.linalg.matrix_norm
- jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')
Compute the norm of a matrix or stack of matrices.
JAX implementation of numpy.linalg.matrix_norm().
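Because this is a JAX implementation of the NumPy function, a quick cross-check against NumPy can be useful. The sketch below compares against numpy.linalg.norm with an explicit axis pair rather than numpy.linalg.matrix_norm, on the assumption that the reader may be on a NumPy release older than 2.0, where matrix_norm is not available. The input stack is arbitrary illustrative data.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary stack of two 3x3 float32 matrices (illustrative values only).
x_np = np.arange(18, dtype=np.float32).reshape(2, 3, 3)

# Spectral norm (largest singular value) of each matrix, computed by JAX.
jax_result = jnp.linalg.matrix_norm(jnp.asarray(x_np), ord=2)

# numpy.linalg.norm with an explicit pair of axes computes the same matrix norm
# per matrix, and exists in NumPy versions that predate numpy.linalg.matrix_norm.
np_result = np.linalg.norm(x_np, ord=2, axis=(-2, -1))

print(np.asarray(jax_result), np_result)
```

Both results have shape (2,), one value per matrix; small float32 differences between the two SVD backends are expected, so compare with a tolerance rather than exact equality.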
- Parameters:
x (ArrayLike) – array of shape (..., M, N) for which to take the norm.
keepdims (bool) – if True, keep the two reduced dimensions in the output as axes of size 1.
ord (str | int) – a string or int specifying the type of norm; the default is the Frobenius norm. See numpy.linalg.norm() for details on the available options.
- Returns:
array containing the norm of x. Has shape x.shape[:-2] if keepdims is False, or shape (..., 1, 1) if keepdims is True.
- Return type:
Array
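The parameters and return shapes described above can be exercised with a short sketch. It assumes the matrix-norm ord options documented for numpy.linalg.norm ('fro', 'nuc', 1, inf, 2, -2) and checks each one against its textbook definition; the 2x2 matrix A and the 4x3x2 stack x are arbitrary illustrative inputs. Passing the float jnp.inf is an assumption carried over from NumPy's conventions, since the signature above annotates ord as str | int.

```python
import jax.numpy as jnp

# Arbitrary 2x2 matrix chosen for illustration.
A = jnp.array([[1.0, -2.0],
               [3.0, 4.0]])

# Singular values of A in descending order, used for the 'nuc', 2 and -2 norms.
s = jnp.linalg.svd(A, compute_uv=False)

# Pair each matrix_norm(A, ord=...) with the same quantity computed from its definition.
checks = {
    "fro": (jnp.linalg.matrix_norm(A), jnp.sqrt(jnp.sum(A ** 2))),
    "nuc": (jnp.linalg.matrix_norm(A, ord="nuc"), jnp.sum(s)),
    "1": (jnp.linalg.matrix_norm(A, ord=1), jnp.max(jnp.sum(jnp.abs(A), axis=0))),
    "inf": (jnp.linalg.matrix_norm(A, ord=jnp.inf), jnp.max(jnp.sum(jnp.abs(A), axis=1))),
    "2": (jnp.linalg.matrix_norm(A, ord=2), s[0]),
    "-2": (jnp.linalg.matrix_norm(A, ord=-2), s[-1]),
}
for name, (got, expected) in checks.items():
    print(f"ord={name}: {float(got):.4f} vs definition {float(expected):.4f}")

# Output shapes for a stack of 4 matrices, each 3x2 (arbitrary values).
x = jnp.arange(24.0).reshape(4, 3, 2)
n = jnp.linalg.matrix_norm(x)                      # shape (4,): one norm per matrix
n_keep = jnp.linalg.matrix_norm(x, keepdims=True)  # shape (4, 1, 1)

# keepdims=True lets the result broadcast against x, e.g. to rescale every
# matrix to unit Frobenius norm (each 3x2 block here has nonzero entries).
x_unit = x / n_keep
print(n.shape, n_keep.shape, jnp.linalg.matrix_norm(x_unit))
```

Keeping the reduced axes is mainly a convenience for this kind of normalization: without keepdims, n would have to be reshaped to (4, 1, 1) by hand before dividing.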
See also
jax.numpy.linalg.vector_norm(): Norm of a vector or stack of vectors.
jax.numpy.linalg.norm(): More general matrix or vector norm.
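The relationship between these functions can be sketched briefly. This assumes vector_norm's default axis=None reduces over every element (as in the array API standard), so for a single matrix it reproduces the Frobenius norm, and that norm() with an explicit axis pair computes the same per-matrix norm as matrix_norm. The matrix m and the stack built from it are illustrative values only.

```python
import jax.numpy as jnp

# Arbitrary example matrix and a two-matrix stack built from it (illustrative values).
m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
stack = jnp.stack([m, 2 * m])                   # shape (2, 2, 3)

# Frobenius norm of one matrix equals the Euclidean norm of all its entries;
# vector_norm with its default axis=None reduces over every element.
fro = jnp.linalg.matrix_norm(m)
flat = jnp.linalg.vector_norm(m)

# For a stack, matrix_norm reduces over the last two axes, matching the
# general norm() when it is given that axis pair explicitly.
stacked = jnp.linalg.matrix_norm(stack, ord=1)  # max absolute column sum per matrix
general = jnp.linalg.norm(stack, ord=1, axis=(-2, -1))

print(fro, flat)
print(stacked, general)
```

The column sums of m are 5, 7 and 9, so the ord=1 norms of the two stacked matrices are 9 and 18.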
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)
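The value printed above can be checked by hand from the definition of the Frobenius norm, the default ord='fro'. Here a_{ij} is notation introduced for the entries of the M x N matrix x:

```latex
\|x\|_F = \sqrt{\sum_{i=1}^{M} \sum_{j=1}^{N} |a_{ij}|^2}
        = \sqrt{1^2 + 2^2 + \cdots + 9^2}
        = \sqrt{285} \approx 16.881943
```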