jax.scipy.linalg.solve
- jax.scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, debug=False, check_finite=True, assume_a='gen')
Solve a linear system of equations.

JAX implementation of scipy.linalg.solve().

This solves a (batched) linear system of equations a @ x = b for x, given a and b. If a is singular, this will return nan or inf values.

- Parameters:
  - a (ArrayLike) – array of shape (..., N, N).
  - b (ArrayLike) – array of shape (..., N) or (..., N, M).
  - lower (bool) – referenced only if assume_a != 'gen'. If True, only the lower triangle of a is used; if False (default), only the upper triangle is used.
  - assume_a (str) – specifies what properties of a can be assumed. Options are:
    - "gen": generic matrix (default)
    - "sym": symmetric matrix
    - "her": Hermitian matrix
    - "pos": positive-definite matrix
  - overwrite_a (bool) – unused by JAX
  - overwrite_b (bool) – unused by JAX
  - debug (bool) – unused by JAX
  - check_finite (bool) – unused by JAX
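As a sketch of how assume_a and lower interact, the example below uses an illustrative symmetric positive-definite matrix (an assumption of this sketch, not taken from the original). With assume_a="pos", zeroing out the triangle that lower says to ignore should leave the solution unchanged. The sketch exercises only the "pos" option, so it makes no claim about how "sym" or "her" are handled.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative symmetric positive-definite matrix (chosen for this sketch).
A = jnp.array([[4., 2., 0.],
               [2., 5., 1.],
               [0., 1., 3.]])
b = jnp.array([1., 2., 3.])

# Default: generic solver, no structural assumptions about A.
x_gen = jax.scipy.linalg.solve(A, b)

# assume_a="pos": A is treated as positive-definite.
x_pos = jax.scipy.linalg.solve(A, b, assume_a="pos")

# Only one triangle should be read: zero out the other one and
# check that the solution does not change.
x_low = jax.scipy.linalg.solve(jnp.tril(A), b, assume_a="pos", lower=True)
x_up = jax.scipy.linalg.solve(jnp.triu(A), b, assume_a="pos", lower=False)

print(x_gen, x_pos, x_low, x_up)
```

All four results should agree to within float32 rounding.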
- Returns:
  An array of the same shape as b containing the solution to the linear system if a is non-singular. If a is singular, the result contains nan or inf values.
- Return type:
  Array
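Because a singular a is not reported by an exception, callers who need to detect it have to inspect the result. The sketch below uses an illustrative rank-deficient matrix (its second row is twice the first; this matrix is an assumption, not from the original) and checks the output with jnp.isfinite. Whether the bad entries come out as inf or nan can depend on the backend's factorization, so the sketch only tests for non-finite values.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Rank-deficient matrix: row 2 is exactly 2 * row 1, so A is singular.
A = jnp.array([[1., 2.],
               [2., 4.]])
b = jnp.array([1., 1.])

# No exception is raised for a singular system.
x = jax.scipy.linalg.solve(A, b)

# Detect the failure explicitly by looking for inf/nan entries.
solved = bool(jnp.all(jnp.isfinite(x)))
print(x, solved)
```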
See also
- jax.scipy.linalg.lu_solve(): Solve via LU factorization.
- jax.scipy.linalg.cho_solve(): Solve via Cholesky factorization.
- jax.scipy.linalg.solve_triangular(): Solve a triangular system.
- jax.numpy.linalg.solve(): NumPy-style API for solving linear systems.
- jax.lax.custom_linear_solve(): Matrix-free linear solver.
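Several of the routines listed above can reach the same solution by different paths. A minimal sketch, using an illustrative symmetric positive-definite matrix (an assumption needed so that the Cholesky routes apply): the LU and Cholesky factorizations are computed once and could be reused for further right-hand sides, while the triangular route performs the two Cholesky substitutions by hand.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

# Illustrative symmetric positive-definite system (chosen for this sketch).
A = jnp.array([[4., 2., 0.],
               [2., 5., 1.],
               [0., 1., 3.]])
b = jnp.array([1., 2., 3.])

x_ref = jsl.solve(A, b)

# LU route: factor once, then solve; the factorization can be reused.
lu_and_piv = jsl.lu_factor(A)
x_lu = jsl.lu_solve(lu_and_piv, b)

# Cholesky route: valid because A is symmetric positive-definite.
c_and_lower = jsl.cho_factor(A)
x_cho = jsl.cho_solve(c_and_lower, b)

# Triangular route: with A = L @ L.T, solve L y = b, then L.T x = y.
L = jnp.linalg.cholesky(A)
y = jsl.solve_triangular(L, b, lower=True)
x_tri = jsl.solve_triangular(L.T, y, lower=False)

# NumPy-style API.
x_np = jnp.linalg.solve(A, b)

print(x_ref, x_lu, x_cho, x_tri, x_np)
```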
Examples
A simple 3x3 linear system:

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jax.scipy.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
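The parameter description also allows several right-hand sides and batched systems. A sketch reusing the matrix from the example above; the extra right-hand-side column and the scaled copy 2. * A are illustrative additions. The batched case passes b with shape (..., N, M) (here M = 1) rather than (..., N), since the matrix form leaves no doubt about which axis holds the right-hand sides.

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])

# Multiple right-hand sides: b of shape (N, M); each column is solved.
B = jnp.array([[14., 1.],
               [16., 0.],
               [10., 0.]])
X = jax.scipy.linalg.solve(A, B)        # shape (3, 2)

# Batched systems: a of shape (2, N, N) with b of shape (2, N, 1).
As = jnp.stack([A, 2. * A])             # shape (2, 3, 3)
bs = jnp.stack([jnp.array([[14.], [16.], [10.]]),
                jnp.array([[28.], [32.], [20.]])])  # shape (2, 3, 1)
xs = jax.scipy.linalg.solve(As, bs)     # shape (2, 3, 1)

print(X.shape, xs.shape)
```

Both batch members scale a and b by the same factor, so each should yield the solution [1., 2., 3.] from the example above.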