jax.numpy.linalg.eigh

jax.numpy.linalg.eigh(a, UPLO=None, symmetrize_input=True)

Compute the eigenvalues and eigenvectors of a Hermitian matrix.

JAX implementation of numpy.linalg.eigh().

Parameters:
  • a (ArrayLike) – array of shape (..., M, M), containing a Hermitian (if complex) or symmetric (if real) matrix, or a batch of such matrices stacked along the leading dimensions.

  • UPLO (str | None) – specifies whether the calculation uses the lower triangular part of a ('L') or the upper triangular part ('U'). If None (default), 'L' is used.

  • symmetrize_input (bool) – if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition.
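The effect of symmetrize_input can be illustrated with a deliberately non-symmetric input. The sketch below assumes that symmetrization means averaging the input with its conjugate transpose, (a + a^H) / 2; the matrix values and variable names are illustrative only.

```python
import jax.numpy as jnp

# A deliberately non-symmetric real matrix (illustrative values).
a = jnp.array([[2.0, 1.0],
               [3.0, 4.0]])

# Default: symmetrize_input=True, so both triangles of `a` contribute.
w_default, _ = jnp.linalg.eigh(a)

# Assumption: symmetrization is (a + a^H) / 2. Symmetrizing by hand
# should then give the same eigenvalues as the default call.
a_sym = (a + a.conj().T) / 2
w_sym, _ = jnp.linalg.eigh(a_sym)

# With symmetrize_input=False, UPLO selects which triangle is used,
# so the two calls below may disagree for a non-symmetric input.
w_lower, _ = jnp.linalg.eigh(a, UPLO='L', symmetrize_input=False)
w_upper, _ = jnp.linalg.eigh(a, UPLO='U', symmetrize_input=False)
```

For inputs that are already symmetric or Hermitian up to rounding, all four calls should agree closely; the distinction matters mainly for inputs that are not exactly symmetric and for differentiation.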

Returns:

A namedtuple (eigenvalues, eigenvectors) where

  • eigenvalues: an array of shape (..., M) containing the eigenvalues, sorted in ascending order.

  • eigenvectors: an array of shape (..., M, M), where column eigenvectors[..., :, i] is the normalized eigenvector corresponding to the eigenvalue eigenvalues[..., i].

Return type:

EighResult
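The properties of the returned values can be checked numerically. The following is a minimal sketch using an arbitrary real symmetric matrix; the tolerance-style checks are illustrative rather than part of the API.

```python
import jax.numpy as jnp

# An arbitrary real symmetric matrix (illustrative values).
a = jnp.array([[ 2.0, -1.0,  0.0],
               [-1.0,  2.0, -1.0],
               [ 0.0, -1.0,  2.0]])

# The result is a namedtuple, so fields can be accessed by name.
result = jnp.linalg.eigh(a)
w, v = result.eigenvalues, result.eigenvectors

# Eigenvalues are sorted in ascending order.
ascending = bool(jnp.all(w[:-1] <= w[1:]))

# Each column satisfies a @ v[:, i] == w[i] * v[:, i]; broadcasting
# `v * w` scales column i of v by w[i].
residual = float(jnp.max(jnp.abs(a @ v - v * w)))

# Columns are normalized and mutually orthogonal: v^H v is the identity.
orthonormality_error = float(jnp.max(jnp.abs(v.conj().T @ v - jnp.eye(3))))
```

In single precision, both residual and orthonormality_error should be small (on the order of floating-point rounding) rather than exactly zero.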

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w, v = jnp.linalg.eigh(a)
>>> w
Array([-1.,  3.], dtype=float32)
>>> with jnp.printoptions(precision=3):
...   v
Array([[-0.707+0.j   , -0.707+0.j   ],
       [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
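Because a may have shape (..., M, M), a stack of matrices can be decomposed in one call. The sketch below builds a random batch of real symmetric matrices (the batch size, seed, and variable names are arbitrary choices for illustration) and reconstructs each matrix from its decomposition.

```python
import jax
import jax.numpy as jnp

# Build a batch of 4 real symmetric 3x3 matrices (illustrative data).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 3, 3))
batch = x + jnp.swapaxes(x, -1, -2)

# One call decomposes every matrix in the batch.
w, v = jnp.linalg.eigh(batch)
# w has shape (4, 3); v has shape (4, 3, 3).

# Reconstruct each matrix as V @ diag(w) @ V^T (real case, so no conjugate).
recon = jnp.einsum('bij,bj,bkj->bik', v, w, v)
max_error = float(jnp.max(jnp.abs(recon - batch)))
```

For complex Hermitian input, the last factor in the reconstruction would need a complex conjugate, i.e. V @ diag(w) @ V^H.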