jax.scipy.linalg.svd
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') → tuple[Array, Array, Array]
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') → Array
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') → Array
- jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') → Array | tuple[Array, Array, Array]
Compute the singular value decomposition.
JAX implementation of scipy.linalg.svd().

The SVD of a matrix \(A\) is given by
\[A = U\Sigma V^H\]
where:
- \(U\) contains the left singular vectors and satisfies \(U^HU=I\)
- \(V\) contains the right singular vectors and satisfies \(V^HV=I\)
- \(\Sigma\) is a diagonal matrix of singular values.
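Because the definition uses the conjugate transpose \(V^H\), the third returned array is already \(V^H\) rather than \(V\). The sketch below checks this on an arbitrary complex 2×3 matrix (the values in `a` are illustrative only): it rebuilds \(A\) from the factors and verifies \(U^HU=I\) and \(V^HV=I\), recovering \(V\) as `vh.conj().T`.

```python
import jax
import jax.numpy as jnp

# Illustrative complex-valued 2x3 matrix (N = 2, M = 3, K = 2).
a = jnp.array([[1 + 2j, 0 - 1j, 3 + 0j],
               [2 - 1j, 4 + 1j, 1 - 3j]], dtype=jnp.complex64)

u, s, vh = jax.scipy.linalg.svd(a, full_matrices=False)

# s is real even for complex input; vh is V^H, so A = U @ diag(s) @ vh.
reconstructed = u @ jnp.diag(s) @ vh

# Recover V by conjugate-transposing vh, then check both orthonormality conditions.
v = vh.conj().T
u_orthonormal = jnp.allclose(u.conj().T @ u, jnp.eye(2), atol=1e-5)
v_orthonormal = jnp.allclose(v.conj().T @ v, jnp.eye(2), atol=1e-5)

print(s.dtype)                                        # float32
print(jnp.allclose(reconstructed, a, atol=1e-4))
print(u_orthonormal, v_orthonormal)
```

For real input the conjugation is a no-op, which is why the real-valued examples on this page can use `vt.T` directly.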
- Parameters:
  - a – input array, of shape (..., N, M)
  - full_matrices – if True (default), compute the full matrices; i.e. u and vh have shape (..., N, N) and (..., M, M). If False, the shapes are (..., N, K) and (..., K, M) with K = min(N, M).
  - compute_uv – if True (default), return the full SVD (u, s, vh). If False, return only the singular values s.
  - overwrite_a – unused by JAX
  - check_finite – unused by JAX
  - lapack_driver – unused by JAX. If you want to select a non-default SVD driver, see jax.lax.linalg.svd(), which provides this functionality.
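The effect of `full_matrices` and `compute_uv` on the output shapes can be checked directly. This is a minimal sketch on an arbitrary 4×3 input (so N = 4, M = 3, K = 3); the matrix contents are illustrative and do not affect the shapes.

```python
import jax
import jax.numpy as jnp

# Arbitrary 4x3 input: N = 4, M = 3, so K = min(N, M) = 3.
x = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)

# full_matrices=True: u is (N, N) and vh is (M, M).
u_full, s_full, vh_full = jax.scipy.linalg.svd(x, full_matrices=True)

# full_matrices=False: u is (N, K) and vh is (K, M).
u_red, s_red, vh_red = jax.scipy.linalg.svd(x, full_matrices=False)

# compute_uv=False: only the singular values, of shape (K,).
s_only = jax.scipy.linalg.svd(x, compute_uv=False)

print(u_full.shape, vh_full.shape)  # (4, 4) (3, 3)
print(u_red.shape, vh_red.shape)    # (4, 3) (3, 3)
print(s_only.shape)                 # (3,)
```

The singular values are the same in every case; only the number of computed singular vectors changes.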
- Returns:
  A tuple of arrays (u, s, vh) if compute_uv is True, otherwise the array s.
  - u: left singular vectors of shape (..., N, N) if full_matrices is True or (..., N, K) otherwise.
  - s: singular values of shape (..., K)
  - vh: conjugate-transposed right singular vectors of shape (..., M, M) if full_matrices is True or (..., K, M) otherwise.

  where K = min(N, M).
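The leading `...` dimensions in these shapes mean the function also accepts a stack of matrices. A minimal sketch, assuming a random batch of five 2×3 matrices (the `PRNGKey` seed and batch size are arbitrary): each output keeps the batch dimension, and broadcasting `s` across the columns of `u` plays the role of `diag(s)` for every matrix at once.

```python
import jax
import jax.numpy as jnp

# Illustrative batch of 5 matrices, each 2x3: N = 2, M = 3, K = 2.
key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (5, 2, 3))

u, s, vh = jax.scipy.linalg.svd(batch, full_matrices=False)
print(u.shape, s.shape, vh.shape)  # (5, 2, 2) (5, 2) (5, 2, 3)

# u * s[..., None, :] scales column j of each u by s[..., j],
# which equals u @ diag(s) for every matrix in the batch.
reconstructed = (u * s[..., None, :]) @ vh
print(jnp.allclose(reconstructed, batch, atol=1e-4))
```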
See also
- jax.numpy.linalg.svd(): NumPy-style SVD API
- jax.lax.linalg.svd(): XLA-style SVD API
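As a quick consistency check across these three entry points, the sketch below computes only the singular values with each one and compares them. It assumes `jax.lax.linalg.svd` accepts `compute_uv` as a keyword argument and returns just `s` when it is False; only that keyword is used here, so no version-specific options (such as algorithm selection) are relied on.

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])

# Singular values from the SciPy-style, NumPy-style and lax-level APIs.
s_scipy = jax.scipy.linalg.svd(x, compute_uv=False)
s_numpy = jnp.linalg.svd(x, compute_uv=False)
s_lax = lax_linalg.svd(x, compute_uv=False)  # options passed by keyword

print(jnp.allclose(s_scipy, s_numpy), jnp.allclose(s_scipy, s_lax))
```

The SciPy- and NumPy-style wrappers differ mainly in their signatures (for example, the unused `overwrite_a`, `check_finite` and `lapack_driver` arguments here), not in the values they return.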
Examples
Consider the SVD of a small real-valued array:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)
The singular vectors are in the columns of u and v = vt.T (for this real-valued input; for complex input, v = vt.conj().T). These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix:
>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
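With the reduced SVD, orthonormal columns only guarantee that `v.T @ v` is the identity; the other product, `v @ v.T`, is a projection onto a 2-dimensional subspace of the 3-dimensional row space and is not the identity. The sketch below, reusing the same illustrative `x`, contrasts this with `full_matrices=True`, where `vt` is square and orthogonal so both products give the identity.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])

# Reduced SVD: v has shape (3, 2).
_, _, vt_red = jax.scipy.linalg.svd(x, full_matrices=False)
v_red = vt_red.T
print(jnp.allclose(v_red.T @ v_red, jnp.eye(2), atol=1e-5))  # True
print(jnp.allclose(v_red @ v_red.T, jnp.eye(3), atol=1e-5))  # False: rank-2 projector

# Full SVD: vt has shape (3, 3) and is orthogonal, so both products are I.
_, _, vt_full = jax.scipy.linalg.svd(x, full_matrices=True)
print(jnp.allclose(vt_full @ vt_full.T, jnp.eye(3), atol=1e-5))
print(jnp.allclose(vt_full.T @ vt_full, jnp.eye(3), atol=1e-5))
```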
Given the SVD, x can be reconstructed via matrix multiplication:
>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
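`u @ jnp.diag(s) @ vt` works when `full_matrices=False`, because `diag(s)` is then K×K. With `full_matrices=True` the factors are `u` of shape (N, N) and `vt` of shape (M, M), so \(\Sigma\) must be an N×M rectangular matrix with the singular values on its leading diagonal. A minimal sketch for the same illustrative `x` (N = 2, M = 3), building that rectangular \(\Sigma\) with `.at[...].set(...)`:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])
u, s, vt = jax.scipy.linalg.svd(x, full_matrices=True)  # u: (2, 2), vt: (3, 3)

# Place the K = 2 singular values on the diagonal of an N x M zero matrix.
n, m = x.shape
k = s.shape[-1]
sigma = jnp.zeros((n, m), dtype=s.dtype).at[:k, :k].set(jnp.diag(s))

x_reconstructed = u @ sigma @ vt
print(sigma.shape)  # (2, 3)
print(jnp.allclose(x_reconstructed, x, atol=1e-4))
```

The extra row of `vt` meets only the zero column of `sigma`, so it does not contribute to the reconstruction.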