jax.numpy.diagflat
- jax.numpy.diagflat(v, k=0)
Return a 2-D array with the flattened input array laid out on the diagonal.
JAX implementation of numpy.diagflat().
This differs from np.diagflat for some scalar values of v: JAX always returns a two-dimensional array, whereas NumPy may return a scalar depending on the type of v.
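A minimal illustration of the scalar case in JAX. A Python int is flattened to a single element, so the result is a 1 x 1 array, and a nonzero k pads it to a (1 + |k|) x (1 + |k|) array. The sketch only shows JAX's behavior; it does not check what NumPy returns for the same input.

```python
import jax.numpy as jnp

# A scalar input is flattened to a length-1 vector, so the result is
# a 1 x 1 two-dimensional array rather than a scalar.
out = jnp.diagflat(5)
print(out.shape)  # (1, 1)

# With k=2 the single element sits two places above the main diagonal
# of a 3 x 3 array; every other entry is zero.
shifted = jnp.diagflat(5, k=2)
print(shifted.shape)  # (3, 3)
```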
- Parameters:
v (ArrayLike) – Input array. It may be N-dimensional; it is flattened to 1-D (row-major order) before being placed on the diagonal.
k (int) – optional, default=0. Diagonal offset. Positive values place the diagonal above the main diagonal; negative values place it below.
- Returns:
A 2-D array of shape (N + |k|, N + |k|), where N is the number of elements in v, with the flattened input placed along the diagonal at offset k. All other entries are zero.
- Return type:
Array
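The parameter and return descriptions above can be checked against a small reference sketch. `diagflat_reference` is a hypothetical helper written for illustration only; it is not part of JAX's API and is not necessarily how JAX implements diagflat. It flattens v, allocates an (N + |k|) x (N + |k|) zero matrix, and scatters the values onto the offset diagonal with `.at[...].set(...)`. The loop also compares against `jnp.diag` applied to the explicitly flattened input.

```python
import jax.numpy as jnp

def diagflat_reference(v, k=0):
    """Hypothetical helper (not part of JAX) mirroring the documented behavior."""
    flat = jnp.ravel(jnp.asarray(v))          # N-D input flattened (row-major) to 1-D
    n = flat.shape[0] + abs(k)                # square output of side N + |k|
    idx = jnp.arange(flat.shape[0])
    rows = idx + max(-k, 0)                   # k < 0 moves the diagonal down
    cols = idx + max(k, 0)                    # k > 0 moves the diagonal right
    out = jnp.zeros((n, n), dtype=flat.dtype) # every other entry stays zero
    return out.at[rows, cols].set(flat)

a = jnp.array([[1, 2],
               [3, 4]])

for k in (-2, -1, 0, 1, 2):
    expected = jnp.diagflat(a, k=k)
    # The hand-written sketch agrees with jnp.diagflat for each offset.
    assert jnp.array_equal(diagflat_reference(a, k), expected)
    # Building the diagonal from the flattened input with jnp.diag also agrees.
    assert jnp.array_equal(jnp.diag(jnp.ravel(a), k=k), expected)

print(jnp.diagflat(a, k=-1).shape)  # (5, 5)
```

With k=-1, the values 1, 2, 3, 4 land at positions (1, 0), (2, 1), (3, 2), (4, 3), one step below the main diagonal.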
See also
Examples
>>> import jax.numpy as jnp
>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)
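Two further uses, sketched with arbitrary example arrays. Reading the k-th diagonal back with `jnp.diagonal(..., offset=k)` recovers the flattened input, which is a quick way to confirm where diagflat placed the values. Separately, `jax.vmap` maps diagflat over a leading batch axis so each row of a 2-D batch becomes its own diagonal matrix; note that without vmap, diagflat would flatten the whole batch into one large matrix instead.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# Round trip: the offset-1 diagonal of diagflat(a, k=1) is ravel(a).
d = jnp.diagflat(a, k=1)
print(jnp.diagonal(d, offset=1))  # [1 2 3 4]

# Batching: vmap applies diagflat to each row separately.
batch = jnp.array([[1, 2, 3],
                   [4, 5, 6]])
stacked = jax.vmap(jnp.diagflat)(batch)
print(stacked.shape)  # (2, 3, 3)

# Without vmap, all six elements are flattened onto a single 6 x 6 diagonal.
print(jnp.diagflat(batch).shape)  # (6, 6)
```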