jax.scipy.linalg.lu_solve
- jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)
Solve a linear system using an LU factorization.
JAX implementation of scipy.linalg.lu_solve(). Uses the output of jax.scipy.linalg.lu_factor().
- Parameters:
  - lu_and_piv (tuple[Array, ArrayLike]) – (lu, piv), the output of lu_factor(). lu is an array of shape (..., M, N) containing L in its strictly lower triangle (the unit diagonal of L is not stored) and U in its upper triangle. piv is an array of shape (..., K), with K = min(M, N), which encodes the pivots.
  - b (ArrayLike) – right-hand side of the linear system. Must have shape (..., M).
  - trans (int) – type of system to solve. Options are:
    - 0: \(A x = b\)
    - 1: \(A^T x = b\)
    - 2: \(A^H x = b\), where \(A^H\) is the conjugate transpose of \(A\)
  - overwrite_b (bool) – unused by JAX.
  - check_finite (bool) – unused by JAX.
- Returns:
  Array of shape (..., N) representing the solution of the linear system.
- Return type:
  Array
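The leading ... in the shapes above stands for optional batch dimensions. The following is a minimal sketch, assuming lu_factor() is given a stack of square, nonsingular matrices (the matrix values are illustrative and not part of the original page). It factors and solves two 2x2 systems in one call and checks the documented shapes of piv and of the result:

```python
import jax
import jax.numpy as jnp

# Illustrative batch of two 2x2 systems: shape (..., M, N) with ... = (2,).
a = jnp.array([[[2., 1.],
                [1., 2.]],
               [[4., 1.],
                [2., 3.]]])
# One right-hand side per system: shape (..., M) = (2, 2).
b = jnp.array([[3., 4.],
               [5., 6.]])

# Factor every matrix in the batch in one call.
lu, piv = jax.scipy.linalg.lu_factor(a)
print(lu.shape, piv.shape)  # (2, 2, 2) (2, 2); piv has shape (..., K), K = min(M, N)

# Solve all systems at once; the result has shape (..., N).
x = jax.scipy.linalg.lu_solve((lu, piv), b)
print(x.shape)  # (2, 2)

# Check each system separately: a[i] @ x[i] should match b[i] up to rounding.
print(jnp.allclose(jnp.einsum('bij,bj->bi', a, x), b, atol=1e-5))  # True
```

Here b has exactly one fewer dimension than lu, so each row of b is treated as the right-hand-side vector of the matching system.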
Examples
Solving a small linear system via LU factorization:
>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])
Compute the LU factorization via lu_factor(), and use it to solve the linear system via lu_solve():
>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
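The trans argument selects which of the three systems is solved from a single factorization. The following minimal sketch uses illustrative matrices that are not from the original page. It factors a non-symmetric matrix once and reuses that factorization for \(A x = b\) (trans=0) and \(A^T x = b\) (trans=1). It then uses a complex matrix for \(A^H x = b\) (trans=2) and checks each result by multiplying back:

```python
import jax
import jax.numpy as jnp

# A non-symmetric matrix, so A x = b and A^T x = b have different solutions.
a = jnp.array([[4., 1.],
               [2., 3.]])
b = jnp.array([5., 6.])

# Factor once...
lufac = jax.scipy.linalg.lu_factor(a)

# ...then reuse the same factorization for both systems.
x0 = jax.scipy.linalg.lu_solve(lufac, b, trans=0)  # solves A x = b
x1 = jax.scipy.linalg.lu_solve(lufac, b, trans=1)  # solves A^T x = b
print(jnp.allclose(a @ x0, b))    # True
print(jnp.allclose(a.T @ x1, b))  # True

# trans=2 solves with the conjugate transpose, which matters for complex input.
ac = jnp.array([[4. + 1.j, 1.],
                [2., 3. - 2.j]])
bc = jnp.array([5. + 0.j, 6. + 1.j])
luc = jax.scipy.linalg.lu_factor(ac)
x2 = jax.scipy.linalg.lu_solve(luc, bc, trans=2)  # solves A^H x = b
print(jnp.allclose(ac.conj().T @ x2, bc, atol=1e-5))  # True
```

Because the factors are computed once, each further solve only needs triangular solves with the stored L and U. This is the usual reason to call lu_factor() and lu_solve() separately when the same matrix appears in several systems.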