jax.numpy.linalg.cholesky

jax.numpy.linalg.cholesky(a, *, upper=False, symmetrize_input=True)

Compute the Cholesky decomposition of a matrix.

JAX implementation of numpy.linalg.cholesky().
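A quick way to see what "JAX implementation of numpy.linalg.cholesky()" means in practice is to factor the same matrix with both libraries and compare. The sketch below assumes NumPy is installed (it is a JAX dependency); the matrix values are arbitrary illustrative data, and agreement is checked only up to float32 rounding.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary small symmetric positive-definite matrix (illustrative data).
a_np = np.array([[4., 2.],
                 [2., 3.]], dtype=np.float32)

L_np = np.linalg.cholesky(a_np)                 # NumPy reference: lower factor
L_jax = jnp.linalg.cholesky(jnp.asarray(a_np))  # JAX: lower factor by default

# The two lower-triangular factors should agree up to float32 rounding.
print(np.allclose(L_np, np.asarray(L_jax), atol=1e-5))
```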

The Cholesky decomposition of a matrix A is:

\[A = U^HU\]

or

\[A = LL^H\]

where U is an upper-triangular matrix, L is a lower-triangular matrix, and \(X^H\) is the Hermitian (conjugate) transpose of X.
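The two forms above can be checked numerically on a complex Hermitian matrix. This is a minimal sketch: the matrix is an arbitrary Hermitian positive-definite example (determinant 24 - 5 = 19 > 0), and the helper `H` is a hypothetical name introduced here for the conjugate transpose.

```python
import jax.numpy as jnp

# Arbitrary Hermitian positive-definite matrix: A equals its conjugate transpose.
A = jnp.array([[4.0 + 0.0j, 1.0 - 2.0j],
               [1.0 + 2.0j, 6.0 + 0.0j]], dtype=jnp.complex64)

def H(X):
    """Hypothetical helper: Hermitian (conjugate) transpose X^H."""
    return X.conj().T

L = jnp.linalg.cholesky(A)              # lower-triangular factor
U = jnp.linalg.cholesky(A, upper=True)  # upper-triangular factor

print(jnp.allclose(L @ H(L), A, atol=1e-5))  # A = L L^H
print(jnp.allclose(H(U) @ U, A, atol=1e-5))  # A = U^H U
print(jnp.allclose(U, H(L)))                  # U is the conjugate transpose of L
print(jnp.allclose(L, jnp.tril(L)))           # L is lower triangular
```

Because U = L^H, the `upper` flag changes only which triangle of the factorization is returned, not the underlying decomposition.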

Parameters:
  • a (ArrayLike) – input array representing a (batched) Hermitian positive-definite matrix. Must have shape (..., N, N).

  • upper (bool) – if True, compute the upper Cholesky factor U. If False (default), compute the lower Cholesky factor L.

  • symmetrize_input (bool) – if True (default), the input is symmetrized by averaging it with its Hermitian transpose, which leads to better behavior under automatic differentiation. Note that when this is True, both the upper and lower triangles of the input are used in computing the decomposition.

Returns:

array of shape (..., N, N) containing the lower-triangular Cholesky factor L, or the upper-triangular factor U if upper=True. If the input is not Hermitian positive-definite, the result will contain NaN entries.
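Since failure is signaled by NaN entries rather than by an exception, callers typically check the result themselves. The sketch below wraps that check in a hypothetical helper, `is_positive_definite`; it relies only on the NaN behavior described above and is a heuristic, since matrices that are numerically borderline may factor without producing NaNs.

```python
import jax.numpy as jnp

def is_positive_definite(m):
    """Hypothetical helper: report m as positive definite when its Cholesky
    factor contains no NaN entries (relies on the documented NaN behavior)."""
    return not bool(jnp.any(jnp.isnan(jnp.linalg.cholesky(m))))

pd = jnp.array([[2., 1.],
                [1., 2.]])          # eigenvalues 1 and 3: positive definite
indefinite = jnp.array([[1., 2.],
                        [2., 1.]])  # eigenvalues 3 and -1: not positive definite

print(is_positive_definite(pd))          # True
print(is_positive_definite(indefinite))  # False
```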

Return type:

Array

Examples

A small real Hermitian positive-definite matrix:

>>> import jax.numpy as jnp
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

Lower Cholesky factorization:

>>> jnp.linalg.cholesky(x)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)

Upper Cholesky factorization:

>>> jnp.linalg.cholesky(x, upper=True)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)

Reconstructing x from its factorization:

>>> L = jnp.linalg.cholesky(x)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)