jax.numpy.linalg.vector_norm
- jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)
Compute the vector norm of a vector or batch of vectors.
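For reference, the quantities selected by `ord` follow the standard vector-norm definitions that numpy.linalg.norm documents. For a vector x with entries x_1 through x_n, a finite order p ≠ 0, and the special values of `ord`, they are shown below (this is a summary of the usual definitions, not a transcription of the JAX source). Note that the `ord=0` case counts nonzero entries and is not a true norm.

```latex
\|x\|_p = \Big( \sum_{i=1}^{n} |x_i|^p \Big)^{1/p}, \qquad
\|x\|_2 = \sqrt{\sum_{i=1}^{n} |x_i|^2} \ \ (\text{default}), \qquad
\|x\|_{\infty} = \max_i |x_i|, \qquad
\|x\|_{-\infty} = \min_i |x_i|, \qquad
\|x\|_0 = \big|\{\, i : x_i \neq 0 \,\}\big|
```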
JAX implementation of numpy.linalg.vector_norm().
- Parameters:
x (ArrayLike) – N-dimensional array for which to take the norm.
axis (int | tuple[int, ...] | None) – optional axis along which to compute the vector norm. If None (default), then x is flattened and the norm is taken over all values.
keepdims (bool) – if True, keep the reduced dimensions in the output as dimensions of size one.
ord (int | str) – A string or int specifying the type of norm; default is the 2-norm. See numpy.linalg.norm() for details on available options.
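A minimal sketch of how the `ord` choices behave on a small vector, cross-checked against numpy.linalg.norm as a reference. It assumes the infinity norms are selected by passing `jnp.inf` or `-jnp.inf` (Python floats, mirroring NumPy, even though the annotation lists `int | str`); the names `x`, `p`, and `results` are purely illustrative.

```python
import jax.numpy as jnp
import numpy as np

# A vector with one zero entry so that ord=0 and ord=-inf are informative.
x = jnp.array([3., -4., 0.])

results = {}
for p in [2, 1, 3, jnp.inf, -jnp.inf, 0]:
    value = float(jnp.linalg.vector_norm(x, ord=p))
    # Cross-check against NumPy's reference implementation for the same order.
    expected = float(np.linalg.norm(np.asarray(x), ord=p))
    assert np.isclose(value, expected, rtol=1e-5)
    results[p] = value

# ord=2:   sqrt(9 + 16) = 5        ord=1:    3 + 4 = 7
# ord=inf: max |x_i| = 4           ord=-inf: min |x_i| = 0
# ord=0:   number of nonzero entries = 2
# ord=3:   (27 + 64) ** (1/3), roughly 4.498
print(results)
```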
- Returns:
array containing the norm of x.
- Return type:
Array
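The shape of the returned array depends on `axis` and `keepdims`: the reduced axis is dropped unless `keepdims=True`, in which case it is kept with size one. The sketch below only uses an integer `axis` or `axis=None`; the variable names are illustrative, and row normalization is one example use of `keepdims`, not something this page prescribes.

```python
import jax.numpy as jnp

# A batch of two 3-vectors, shape (2, 3).
x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

full = jnp.linalg.vector_norm(x)                              # axis=None: all 6 values, scalar
full_kept = jnp.linalg.vector_norm(x, keepdims=True)          # flattened norm, shape (1, 1)
rows = jnp.linalg.vector_norm(x, axis=1)                      # one norm per row
rows_kept = jnp.linalg.vector_norm(x, axis=1, keepdims=True)  # reduced axis kept with size 1
cols = jnp.linalg.vector_norm(x, axis=0)                      # one norm per column

print(full.shape, full_kept.shape, rows.shape, rows_kept.shape, cols.shape)
# () (1, 1) (2,) (2, 1) (3,)

# The keepdims=True result broadcasts against x, e.g. to scale each row to unit norm.
unit_rows = x / rows_kept
print(jnp.linalg.vector_norm(unit_rows, axis=1))  # each entry is approximately 1.0
```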
See also
jax.numpy.linalg.matrix_norm(): Norm of a matrix or stack of matrices.
jax.numpy.linalg.norm(): More general matrix or vector norm.
Examples
Norm of a single vector:
>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)
Norm of a batch of vectors:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
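As a sanity check on the batched example, the per-row result should agree both with mapping the unbatched call over rows via `jax.vmap` and with the explicit 2-norm formula sqrt(sum |x_i|^2). This is a sketch for illustration; the names `batched`, `mapped`, and `manual` are hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# Batched call: one 2-norm per row.
batched = jnp.linalg.vector_norm(x, axis=1)

# Map the single-vector call over the leading axis instead.
mapped = jax.vmap(jnp.linalg.vector_norm)(x)

# Direct evaluation of the definition: sqrt(sum_i |x_i|^2) along each row.
manual = jnp.sqrt(jnp.sum(jnp.abs(x) ** 2, axis=1))

print(batched, mapped, manual)
```

The `axis=1` form is the direct way to reduce a batch; `jax.vmap` becomes useful when the per-vector computation is wrapped in a larger function.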