jax.scipy.linalg.eigh_tridiagonal
- jax.scipy.linalg.eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', select_range=None, tol=None)
Solve the eigenvalue problem for a symmetric real tridiagonal matrix.
JAX implementation of scipy.linalg.eigh_tridiagonal().
- Parameters:
d (ArrayLike) – real-valued array of shape (N,) specifying the diagonal elements.
e (ArrayLike) – real-valued array of shape (N - 1,) specifying the off-diagonal elements.
eigvals_only (bool) – If True, return only the eigenvalues (default: False). Computation of eigenvectors is not yet implemented, so eigvals_only must be set to True.
select (str) – specify which eigenvalues to calculate. Supported values are:
  'a': all eigenvalues
  'i': eigenvalues with indices select_range[0] <= i <= select_range[1]
  JAX does not currently implement select='v'.
select_range (tuple[float, float] | None) – range of values used when select='i'.
tol (float | None) – absolute tolerance to use when solving for the eigenvalues.
- Returns:
An array of eigenvalues with shape (N,).
- Return type:
Array
See also
jax.scipy.linalg.eigh(): general Hermitian eigenvalue solver
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> d = jnp.array([1., 2., 3., 4.])
>>> e = jnp.array([1., 1., 1.])
>>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True)
>>> eigvals
Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32)
For comparison, we can construct the full matrix and compute the same result using eigh():
>>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
>>> A
Array([[1., 1., 0., 0.],
       [1., 2., 1., 0.],
       [0., 1., 3., 1.],
       [0., 0., 1., 4.]], dtype=float32)
>>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True)
>>> jnp.allclose(eigvals, eigvals_full)
Array(True, dtype=bool)
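The select and select_range parameters described above can be used to request only part of the spectrum. The following is a minimal sketch, written as a plain script rather than a doctest, that asks for the two smallest eigenvalues of the same matrix. It also shows one possible workaround for the missing eigenvector support: falling back to the dense solver eigh() on the assembled matrix. The variable names (lowest_two, all_eigvals, w, v) are illustrative only, and the dense fallback is an assumption about what may be acceptable for small N, not a recommendation from the JAX documentation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh, eigh_tridiagonal

d = jnp.array([1., 2., 3., 4.])  # diagonal, shape (N,)
e = jnp.array([1., 1., 1.])      # off-diagonal, shape (N - 1,)

# select='i' with select_range=(0, 1) requests the eigenvalues whose
# indices i satisfy 0 <= i <= 1, i.e. the two smallest ones.
lowest_two = eigh_tridiagonal(
    d, e, eigvals_only=True, select='i', select_range=(0, 1))

# Cross-check against the full spectrum returned by select='a' (the default).
all_eigvals = eigh_tridiagonal(d, e, eigvals_only=True)
print(jnp.allclose(lowest_two, all_eigvals[:2], atol=1e-4))

# Eigenvectors are not available from eigh_tridiagonal, so this sketch
# builds the dense matrix and uses the general Hermitian solver instead.
A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
w, v = eigh(A)  # w: eigenvalues, v: eigenvectors stored as columns

# Each column of v should satisfy A @ v[:, k] == w[k] * v[:, k].
print(jnp.allclose(A @ v, v * w, atol=1e-4))
```

The dense fallback discards the tridiagonal structure, so it is mainly useful when N is small or when eigenvectors are needed.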