jax.numpy.linalg.solve

jax.numpy.linalg.solve(a, b)

Solve a linear system of equations.

JAX implementation of numpy.linalg.solve().
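Because the function mirrors NumPy's solver, one quick sanity check is to run both on the same input. This minimal sketch assumes NumPy is installed alongside JAX. The matrix and right-hand side are arbitrary illustrative values, and the results are compared with a tolerance rather than for exact equality.

```python
import numpy as np
import jax.numpy as jnp

# An arbitrary, well-conditioned 2x2 system (illustrative values; exact solution is [2, 3]).
a = np.array([[3., 1.],
              [1., 2.]], dtype=np.float32)
b = np.array([9., 8.], dtype=np.float32)

x_np = np.linalg.solve(a, b)    # NumPy reference result
x_jax = jnp.linalg.solve(a, b)  # JAX result (a jax.Array)

# The two should agree up to floating-point tolerance.
print(np.allclose(x_np, np.asarray(x_jax), atol=1e-5))
```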

This solves a (batched) linear system of equations a @ x = b for x given a and b.
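To make "batched" concrete, here is a minimal sketch with two illustrative 2x2 systems stacked along a leading axis. Each right-hand side is passed as an N x 1 matrix (the (..., N, M) form with M = 1). The batched result is then compared against solving each system separately in a Python loop.

```python
import jax.numpy as jnp

# Two independent 2x2 systems stacked along a leading batch axis: shape (2, 2, 2).
a = jnp.array([[[2., 0.], [0., 4.]],
               [[1., 1.], [1., -1.]]])
# One right-hand side per system, each an N x 1 matrix: shape (2, 2, 1).
b = jnp.array([[[2.], [8.]],
               [[3.], [1.]]])

x = jnp.linalg.solve(a, b)  # shape (2, 2, 1)

# Solving each system on its own should give the same answers as the batched call.
x_loop = jnp.stack([jnp.linalg.solve(a[i], b[i]) for i in range(a.shape[0])])
print(x.shape, jnp.allclose(x, x_loop))
```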

If a is singular, no error is raised; instead, the result will contain nan or inf values.
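A minimal illustration, using a rank-deficient 2x2 matrix chosen for the purpose: the call returns normally, and the singularity only shows up as non-finite entries in the output. Which entries come out as nan and which as inf may depend on the backend, so this sketch only checks for non-finite values.

```python
import jax.numpy as jnp

# The second row is twice the first, so this matrix is singular (determinant 0).
a_singular = jnp.array([[1., 2.],
                        [2., 4.]])
b = jnp.array([1., 2.])

x = jnp.linalg.solve(a_singular, b)

# No exception is raised; the result simply contains non-finite entries.
print(jnp.isfinite(x).all())
```

Because no exception is raised, code that may receive singular inputs can test the output with jnp.isfinite, as above.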

Parameters:
  • a (ArrayLike) – array of shape (..., N, N).

  • b (ArrayLike) – array of shape (N,) (for a 1-dimensional right-hand side) or (..., N, M) (for batched 2-dimensional right-hand sides).
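When b has shape (N, M), each of its M columns is treated as a separate right-hand side. The sketch below uses illustrative values and checks that one call with a 2-column b matches two calls with 1-D right-hand sides.

```python
import jax.numpy as jnp

a = jnp.array([[4., 1.],
               [2., 3.]])
# Two right-hand sides stored as the columns of an (N, M) = (2, 2) matrix.
b = jnp.array([[5., 3.],
               [5., -1.]])

x = jnp.linalg.solve(a, b)  # shape (2, 2): column j solves a @ x[:, j] = b[:, j]

# Same result as solving column by column with 1-D right-hand sides.
x_cols = jnp.stack([jnp.linalg.solve(a, b[:, j]) for j in range(b.shape[1])], axis=1)
print(jnp.allclose(x, x_cols))
```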

Returns:

An array containing the result of the linear solve if a is non-singular. The result has shape (..., N) if b is of shape (N,), and has shape (..., N, M) otherwise. If a is singular, the result contains nan or inf values.

Return type:

Array
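The shape rules in the Returns description can be checked directly. This sketch uses an illustrative batch of two scaled identity matrices, so every system is trivially non-singular. Only the output shapes matter here.

```python
import jax.numpy as jnp

n = 3
# A batch of two non-singular N x N matrices: shape (2, N, N).
a = jnp.stack([2. * jnp.eye(n), 3. * jnp.eye(n)])

b_vec = jnp.ones(n)          # shape (N,)
b_mat = jnp.ones((2, n, 4))  # shape (..., N, M) with M = 4

print(jnp.linalg.solve(a, b_vec).shape)     # (2, 3): the one vector is reused for each matrix
print(jnp.linalg.solve(a, b_mat).shape)     # (2, 3, 4)
print(jnp.linalg.solve(a[0], b_vec).shape)  # (3,)
```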

Examples

A simple 3x3 linear system:

>>> import jax.numpy as jnp
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)