jax.numpy.linalg.cholesky
- jax.numpy.linalg.cholesky(a, *, upper=False, symmetrize_input=True)
Compute the Cholesky decomposition of a matrix.
JAX implementation of numpy.linalg.cholesky().
The Cholesky decomposition of a matrix A is:
\[A = U^H U\]
or
\[A = L L^H\]
where U is an upper-triangular matrix, L is a lower-triangular matrix, and \(X^H\) is the Hermitian (conjugate) transpose of X.
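To make the two forms concrete for complex input, here is a small sketch that checks \(A = L L^H\) and \(A = U^H U\) numerically. The matrix A is an arbitrary Hermitian positive-definite example chosen for illustration. The last check relies on the uniqueness of the Cholesky factor, which implies that the upper factor is the conjugate transpose of the lower one.

```python
import jax.numpy as jnp

# Arbitrary complex Hermitian positive-definite matrix (illustrative choice):
# A equals its own conjugate transpose, and det(A) = 24 - |1 - 2j|^2 = 19 > 0.
A = jnp.array([[4.0 + 0.0j, 1.0 - 2.0j],
               [1.0 + 2.0j, 6.0 + 0.0j]])

L = jnp.linalg.cholesky(A)              # lower factor, A = L L^H
U = jnp.linalg.cholesky(A, upper=True)  # upper factor, A = U^H U

# X^H is the conjugate transpose of X.
L_H = L.conj().T
U_H = U.conj().T

print(bool(jnp.allclose(A, L @ L_H, atol=1e-5)))  # True
print(bool(jnp.allclose(A, U_H @ U, atol=1e-5)))  # True
print(bool(jnp.allclose(U, L_H)))                 # True: U = L^H
```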
- Parameters:
a (ArrayLike) – input array representing a (batched) Hermitian positive-definite matrix. Must have shape (..., N, N).
upper (bool) – if True, compute the upper Cholesky factor U. If False (default), compute the lower Cholesky factor L.
symmetrize_input (bool) – if True (default), the input is symmetrized before it is factored, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input are used in computing the decomposition.
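The sketch below exercises these parameters on illustrative matrices: a batched input of shape (2, 3, 3), the effect of symmetrize_input=True under the assumption that "symmetrized" means replacing a by its Hermitian part (a + a^H) / 2, and a gradient taken through the factorization. The helper logdet is hypothetical and defined only for this illustration; it uses the identity log det(A) = 2 · Σ log(diag(L)).

```python
import jax
import jax.numpy as jnp

# Illustrative 3x3 symmetric positive-definite matrix
# (leading principal minors 4, 16, 40.8 are all positive).
base = jnp.array([[4.0, 2.0, 0.6],
                  [2.0, 5.0, 1.5],
                  [0.6, 1.5, 3.0]])

# Batched input: leading dimensions are treated as a batch, one factor per matrix.
batch = jnp.stack([base, 2.0 * base])
L_batch = jnp.linalg.cholesky(batch)
print(L_batch.shape)  # (2, 3, 3)

# Assumption: with symmetrize_input=True the input is replaced by (a + a^H) / 2,
# so a matrix whose triangles disagree factors like its Hermitian part.
a = base.at[0, 2].add(0.4)  # perturb only the upper triangle
hermitian_part = (a + a.conj().T) / 2
print(bool(jnp.allclose(jnp.linalg.cholesky(a),
                        jnp.linalg.cholesky(hermitian_part))))  # True

# Hypothetical helper: log-determinant from the diagonal of the lower factor.
def logdet(m):
    L = jnp.linalg.cholesky(m)
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

# For a symmetric positive-definite input, the gradient of log det is A^{-1},
# which is symmetric because the symmetrized input depends on both triangles.
grad_logdet = jax.grad(logdet)(base)
print(bool(jnp.allclose(grad_logdet, jnp.linalg.inv(base), atol=1e-4)))
```

Scaling a positive-definite matrix by 2 scales its Cholesky factor by √2, which is a quick sanity check on the batched result.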
- Returns:
array of shape (..., N, N) representing the Cholesky decomposition of the input. If the input is not Hermitian positive-definite, the result will contain NaN entries.
- Return type:
Array
See also
jax.scipy.linalg.cholesky(): SciPy-style Cholesky API
jax.lax.linalg.cholesky(): XLA-style Cholesky API
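A sketch comparing the three entry points on the same matrix. It assumes that jax.scipy.linalg.cholesky follows SciPy's convention of returning the upper factor by default (lower=False) and that jax.lax.linalg.cholesky returns the lower factor. The matrix x is the one used in the examples below.

```python
import jax.numpy as jnp
from jax.lax.linalg import cholesky as lax_cholesky
from jax.scipy.linalg import cholesky as scipy_cholesky

x = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])

# NumPy-style: lower factor by default, upper factor with upper=True.
L_np = jnp.linalg.cholesky(x)
U_np = jnp.linalg.cholesky(x, upper=True)

# SciPy-style: upper factor by default (lower=False); lower=True gives L.
U_sp = scipy_cholesky(x)
L_sp = scipy_cholesky(x, lower=True)

# XLA-style: returns the lower factor.
L_lax = lax_cholesky(x)

print(bool(jnp.allclose(L_np, L_sp)),
      bool(jnp.allclose(U_np, U_sp)),
      bool(jnp.allclose(L_np, L_lax)))  # True True True
```

The main practical difference is the default triangle: code ported from SciPy that relies on the default gets U, while jnp.linalg.cholesky gets L unless upper=True is passed.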
Examples
A small real Hermitian positive-definite matrix:
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])
Lower Cholesky factorization:
>>> jnp.linalg.cholesky(x)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)
Upper Cholesky factorization:
>>> jnp.linalg.cholesky(x, upper=True)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)
Reconstructing x from its factorization (for real input, \(L^H\) reduces to L.T):
>>> L = jnp.linalg.cholesky(x)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
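Because a non-positive-definite input yields NaN entries rather than raising an error, callers may want to check the result explicitly. A minimal sketch, using an illustrative symmetric but indefinite matrix:

```python
import jax.numpy as jnp

# Symmetric but indefinite: eigenvalues are 1 + 2 = 3 and 1 - 2 = -1,
# so no Cholesky factor exists.
indefinite = jnp.array([[1.0, 2.0],
                        [2.0, 1.0]])

L_bad = jnp.linalg.cholesky(indefinite)

# No exception is raised; the failure shows up as NaN entries in the result.
factorization_failed = bool(jnp.isnan(L_bad).any())
print(factorization_failed)  # True
```

Checking for NaN with jnp.isnan keeps the test traceable, so the same pattern also works inside jax.jit, where raising on data-dependent conditions is not possible.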