jax.numpy.linalg.cond
- jax.numpy.linalg.cond(x, p=None)
Compute the condition number of a matrix.
JAX implementation of numpy.linalg.cond().
The condition number is defined as norm(x, p) * norm(inv(x), p). For p = 2 (the default), the condition number is the ratio of the largest to the smallest singular value.
- Parameters:
x (ArrayLike) – array of shape (..., M, N) for which to compute the condition number.
p – the order of the norm to use. One of {None, 1, -1, 2, -2, inf, -inf, 'fro'}; see jax.numpy.linalg.norm() for the meaning of these. The default is p = None, which is equivalent to p = 2. If p is not in {None, 2, -2}, then x must be square, i.e. M = N.
- Returns:
array of shape x.shape[:-2] containing the condition number.
Examples
Well-conditioned matrix:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2],
...                [2, 1]])
>>> jnp.linalg.cond(x)
Array(3., dtype=float32)
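The definition given at the top of this page can be checked numerically for this matrix. The following is a minimal sketch rather than part of the original examples: it assumes the default float32 precision, and the variable names (s, ratio, product) are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2.],
               [2., 1.]])

# p = 2 (the default): ratio of the largest to the smallest singular value.
s = jnp.linalg.svd(x, compute_uv=False)
cond_2 = jnp.linalg.cond(x)
ratio = s.max() / s.min()

# p = 1: norm(x, 1) * norm(inv(x), 1). This form requires a square matrix.
cond_1 = jnp.linalg.cond(x, p=1)
product = jnp.linalg.norm(x, 1) * jnp.linalg.norm(jnp.linalg.inv(x), 1)

# Each pair should agree up to floating-point rounding.
print(cond_2, ratio)
print(cond_1, product)
```

Comparing with a tolerance (for example via jnp.allclose) is safer than testing for exact equality, since both sides are computed in floating point.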
Ill-conditioned matrix:
>>> x = jnp.array([[1, 2],
...                [0, 0]])
>>> jnp.linalg.cond(x)
Array(inf, dtype=float32)
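The shape rules from the parameter description can be illustrated with a batched, non-square input. This is a sketch under stated assumptions: the array values are arbitrary (chosen only so that each matrix has full rank), and the reading of p = -2 as the smallest singular value divided by the largest follows NumPy's convention, which the JAX implementation is assumed to share.

```python
import jax.numpy as jnp

# A batch of two 3x2 matrices, so x has shape (..., M, N) = (2, 3, 2).
# The values are arbitrary; each matrix has linearly independent columns.
x = jnp.arange(12, dtype=jnp.float32).reshape(2, 3, 2)

# The default p=None (equivalent to p=2) accepts non-square matrices.
c = jnp.linalg.cond(x)

# p=-2 is also allowed for non-square input (assumed here to be the
# reciprocal ratio: smallest singular value over the largest).
c_neg = jnp.linalg.cond(x, p=-2)

print(c.shape)    # one condition number per matrix, i.e. x.shape[:-2]
print(c * c_neg)  # expected to be close to 1 under the assumption above
```

For any other value of p, such as 1 or 'fro', the trailing two dimensions must be equal, because the definition involves inv(x).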