jax.scipy.linalg.lu
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]
Compute the LU decomposition.
JAX implementation of scipy.linalg.lu().
The LU decomposition of a matrix A is:
\[A = P L U\]
where P is a permutation matrix, L is lower-triangular and U is upper-triangular.
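To make the definition concrete, here is a minimal, illustrative sketch of LU decomposition with partial pivoting: Gaussian elimination that, at each step, first swaps in the row with the largest remaining entry in the pivot column. The function name `lu_partial_pivot` is hypothetical, the sketch assumes a square, nonsingular input, and it is not how JAX computes the factorization internally. On the 3x3 matrix from the Examples section its factors should agree with jax.scipy.linalg.lu.

```python
import jax.numpy as jnp
import jax.scipy.linalg


def lu_partial_pivot(a):
    """Illustrative LU decomposition with partial pivoting (square input only).

    Returns (P, L, U) with a == P @ L @ U, L unit lower-triangular and U
    upper-triangular. Uses Python loops for readability, so it is slow and
    cannot be traced under jax.jit; it also assumes no zero pivot occurs.
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    n = a.shape[0]
    U = a
    L = jnp.eye(n, dtype=a.dtype)
    perm = jnp.arange(n)  # perm[i] = row of `a` currently at position i

    for k in range(n - 1):
        # Partial pivoting: choose the row at or below k with the largest |U[i, k]|.
        p = k + int(jnp.argmax(jnp.abs(U[k:, k])))
        if p != k:
            rows, swapped = jnp.array([k, p]), jnp.array([p, k])
            # Swap rows k and p in U, in the permutation record, and in the
            # multipliers already stored in columns 0..k-1 of L.
            U = U.at[rows].set(U[swapped])
            perm = perm.at[rows].set(perm[swapped])
            L = L.at[rows, :k].set(L[swapped, :k])
        # Eliminate the entries below the pivot; the multipliers go into L.
        factors = U[k + 1:, k] / U[k, k]
        L = L.at[k + 1:, k].set(factors)
        U = U.at[k + 1:].add(-factors[:, None] * U[k])
        U = U.at[k + 1:, k].set(0.0)  # store exact zeros below the pivot

    # Row i of L @ U is row perm[i] of `a`, so a == P @ L @ U with P[perm[i], i] = 1.
    P = jnp.eye(n, dtype=a.dtype)[:, perm]
    return P, L, U


a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = lu_partial_pivot(a)
P_ref, L_ref, U_ref = jax.scipy.linalg.lu(a)
print(jnp.allclose(P @ L @ U, a, atol=1e-5))
print(jnp.allclose(P, P_ref), jnp.allclose(L, L_ref, atol=1e-5), jnp.allclose(U, U_ref, atol=1e-5))
```

Recording the row swaps in `perm` and building P only at the end keeps each step cheap; a library routine typically stores the same information compactly as pivot indices instead of an explicit matrix.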
- Parameters:
a – array of shape (..., M, N) to decompose.
permute_l – if True, then permute L and return (P @ L, U) (default: False).
overwrite_a – not used by JAX.
check_finite – not used by JAX.
- Returns:
P is a permutation matrix of shape (..., M, M).
L is a lower-triangular matrix of shape (..., M, K).
U is an upper-triangular matrix of shape (..., K, N),
with K = min(M, N).
- Return type:
A tuple of arrays (P @ L, U) if permute_l is True, else (P, L, U).
See also
jax.numpy.linalg.lu(): NumPy-style API for LU decomposition.
jax.lax.linalg.lu(): XLA-style API for LU decomposition.
jax.scipy.linalg.lu_solve(): LU-based linear solver.
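Because P is a permutation matrix, its inverse is its transpose, so the three factors returned by lu can solve A x = b with two triangular solves: L y = P^T b, then U x = y. The sketch below does this with jax.scipy.linalg.solve_triangular and compares the result against jax.scipy.linalg.lu_solve, which expects the compact (lu, piv) output of jax.scipy.linalg.lu_factor rather than the (P, L, U) tuple. The right-hand side b is an arbitrary choice for illustration.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
b = jnp.array([1., 2., 3.])  # arbitrary right-hand side for illustration

P, L, U = jax.scipy.linalg.lu(a)
# a @ x = b  <=>  L @ (U @ x) = P.T @ b, because P.T is the inverse of P
y = jax.scipy.linalg.solve_triangular(L, P.T @ b, lower=True, unit_diagonal=True)
x = jax.scipy.linalg.solve_triangular(U, y, lower=False)

# The same system via the compact factorization that lu_solve consumes
x_ref = jax.scipy.linalg.lu_solve(jax.scipy.linalg.lu_factor(a), b)

print(x)  # close to the exact solution [4, -6, 3]
print(jnp.allclose(x, x_ref, atol=1e-4))
```

When many right-hand sides share one matrix, factoring once and reusing the factors avoids repeating the O(n^3) elimination for each solve.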
Examples
An LU decomposition of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)
P is a permutation matrix, i.e. each row and column contains a single 1:
>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)
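Two consequences of P being a permutation matrix can be checked with the same `a`: P is orthogonal (P.T @ P is the identity), and L @ U equals P.T @ a, which is simply `a` with its rows reordered. The index vector `perm` is a helper introduced for illustration, read off from the position of the 1 in each column of P.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(a)

# A permutation matrix is orthogonal: its inverse is its transpose
print(jnp.array_equal(P.T @ P, jnp.eye(3)))  # True

# perm[i] is the row of P holding the 1 in column i
perm = jnp.argmax(P, axis=0)
print(perm)  # [1 0 2]

# Row i of L @ U is row perm[i] of a, i.e. L @ U == P.T @ a
print(jnp.allclose(L @ U, a[perm], atol=1e-6))  # True
```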
L and U are lower-triangular and upper-triangular matrices, respectively:
>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]
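The triangular structure can also be confirmed programmatically rather than by eye. In this sketch, jnp.tril and jnp.triu zero out the entries above and below the diagonal, respectively, so comparing each factor with its masked copy tests its shape; the diagonal of L is all ones (L is unit lower-triangular), matching the printed output.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(a)

# L is lower-triangular with ones on its diagonal
L_is_unit_lower = jnp.array_equal(L, jnp.tril(L)) & jnp.all(jnp.diag(L) == 1.0)
# U is upper-triangular
U_is_upper = jnp.array_equal(U, jnp.triu(U))
print(L_is_unit_lower, U_is_upper)
```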
The original matrix can be reconstructed by multiplying the three together:
>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
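With permute_l=True the permutation is folded into the first factor, so only two arrays are returned. A short sketch with the same matrix: the first output matches P @ L from the three-output call, which for this `a` is no longer lower-triangular because its rows are permuted, and the two factors still reconstruct `a`.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(a)
PL, U2 = jax.scipy.linalg.lu(a, permute_l=True)

# The first output is the row-permuted lower-triangular factor P @ L
print(jnp.allclose(PL, P @ L), jnp.allclose(U2, U))
# Two factors are enough to rebuild a
print(jnp.allclose(PL @ U2, a, atol=1e-6))
```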