jax.numpy.linalg.eigh
- jax.numpy.linalg.eigh(a, UPLO=None, symmetrize_input=True)
Compute the eigenvalues and eigenvectors of a Hermitian matrix.
JAX implementation of numpy.linalg.eigh().
- Parameters:
  - a (ArrayLike) – array of shape (..., M, M) containing the Hermitian (if complex) or symmetric (if real) matrix, or a batch of such matrices.
  - UPLO (str | None) – specifies whether the calculation uses the lower triangular part of a ('L', the default) or the upper triangular part ('U'). This choice only matters when symmetrize_input is False.
  - symmetrize_input (bool) – if True (the default), the input is symmetrized before the decomposition, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input are used in computing the decomposition.
- Returns:
  A namedtuple (eigenvalues, eigenvectors) where
  - eigenvalues: a real array of shape (..., M) containing the eigenvalues, sorted in ascending order.
  - eigenvectors: an array of shape (..., M, M) whose columns are the corresponding normalized eigenvectors. Writing w, v = eigh(a) for a single matrix, column v[:, i] is the eigenvector for eigenvalue w[i].
- Return type:
EighResult
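The interaction between UPLO and symmetrize_input is easiest to see on a deliberately non-symmetric input. The sketch below assumes the CPU (LAPACK) backend and assumes that symmetrization replaces the input by its Hermitian part, (b + b^H) / 2; the eigenvalues noted in the comments are hand calculations under those assumptions. The matrix b is an arbitrary illustration.

```python
import jax.numpy as jnp

# A deliberately non-symmetric real matrix: its lower and upper triangles disagree.
b = jnp.array([[2.0, 1.0],
               [0.0, 2.0]])

# Default symmetrize_input=True. Assuming the input is replaced by (b + b^H) / 2,
# eigh sees [[2, 0.5], [0.5, 2]], whose eigenvalues are 1.5 and 2.5.
w_sym = jnp.linalg.eigh(b).eigenvalues

# symmetrize_input=False, lower triangle (UPLO='L'): eigh sees
# [[2, 0], [0, 2]], whose eigenvalues are 2 and 2.
w_lower = jnp.linalg.eigh(b, UPLO='L', symmetrize_input=False).eigenvalues

# symmetrize_input=False, upper triangle (UPLO='U'): eigh sees
# [[2, 1], [1, 2]], whose eigenvalues are 1 and 3.
w_upper = jnp.linalg.eigh(b, UPLO='U', symmetrize_input=False).eigenvalues

print(w_sym, w_lower, w_upper)
```

With the default settings UPLO makes no visible difference here, because after symmetrization both triangles hold the same values.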
See also
- jax.numpy.linalg.eig(): general eigenvalue decomposition.
- jax.numpy.linalg.eigvalsh(): compute eigenvalues only.
- jax.scipy.linalg.eigh(): SciPy API for Hermitian eigendecomposition.
- jax.lax.linalg.eigh(): XLA API for Hermitian eigendecomposition.
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w, v = jnp.linalg.eigh(a)
>>> w
Array([-1.,  3.], dtype=float32)
>>> with jnp.printoptions(precision=3):
...   v
Array([[-0.707+0.j   , -0.707+0.j   ],
       [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
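Two further checks follow, sketched for float32 inputs on CPU. The first uses a hypothetical batch of random real symmetric matrices to confirm the documented shapes, the ascending ordering, and that a @ v equals v with its columns scaled by w. The second uses jax.grad on the largest eigenvalue: for a symmetric matrix with distinct eigenvalues, that gradient should equal the outer product of the corresponding unit eigenvector with itself. The tolerances are judgment calls for single precision, not documented guarantees.

```python
import jax
import jax.numpy as jnp

# Hypothetical batch of three real symmetric 4x4 matrices, shape (3, 4, 4).
x = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))
a = (x + jnp.swapaxes(x, -1, -2)) / 2

w, v = jnp.linalg.eigh(a)  # w: (3, 4), v: (3, 4, 4)

# Eigenvalues come back in ascending order along the last axis.
is_sorted = bool(jnp.all(jnp.diff(w, axis=-1) >= 0))

# Column i of each v is an eigenvector for w[..., i], so a @ v == v * w (w broadcast over rows).
residual = float(jnp.max(jnp.abs(a @ v - v * w[..., None, :])))

# The factors reconstruct the input: a == v @ diag(w) @ v^H.
recon = (v * w[..., None, :]) @ jnp.swapaxes(v.conj(), -1, -2)
recon_err = float(jnp.max(jnp.abs(recon - a)))

# Gradient of the largest eigenvalue of a symmetric matrix with distinct eigenvalues;
# the expected result is the outer product of its unit eigenvector with itself.
m = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
grad_top = jax.grad(lambda s: jnp.linalg.eigh(s).eigenvalues[-1])(m)
v_top = jnp.linalg.eigh(m).eigenvectors[:, -1]
expected_grad = jnp.outer(v_top, v_top)

print(w.shape, v.shape, is_sorted, residual, recon_err)
```

Because the default symmetrize_input=True differentiates through the symmetrization, the gradient above is itself symmetric even though jax.grad treats every entry of m as an independent variable.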