jax.scipy.linalg.block_diag
- jax.scipy.linalg.block_diag(*arrs)
Create a block diagonal matrix from input arrays.
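To make the layout concrete, here is a minimal from-scratch sketch of the same construction. `naive_block_diag` is a hypothetical helper written only for illustration (it is not part of JAX). It assumes every input is already 2-D, writes each block into a zero matrix at a running (row, column) offset, and compares the result with `jax.scipy.linalg.block_diag`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


def naive_block_diag(*blocks):
    """Hypothetical reference helper: place 2-D blocks along the diagonal.

    Assumes every block is already 2-D; the library function also accepts
    lower-dimensional inputs, which this sketch does not handle.
    """
    n_rows = sum(b.shape[0] for b in blocks)
    n_cols = sum(b.shape[1] for b in blocks)
    out = jnp.zeros((n_rows, n_cols), dtype=jnp.result_type(*blocks))
    r = c = 0
    for b in blocks:
        h, w = b.shape
        # Write block b at offset (r, c); every other entry stays zero.
        out = out.at[r:r + h, c:c + w].set(b)
        r += h
        c += w
    return out


A = jnp.array([[1.0, 2.0]])      # shape (1, 2)
B = jnp.array([[3.0], [4.0]])    # shape (2, 1)

result = naive_block_diag(A, B)  # shape (3, 3)
expected = block_diag(A, B)
print(result)
```

JAX arrays are immutable, so `.at[...].set(...)` returns a new array instead of modifying `out` in place; this is why `out` is reassigned inside the loop.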
JAX implementation of scipy.linalg.block_diag().
- Parameters:
*arrs (ArrayLike) – arrays of at most two dimensions
- Returns:
2D block-diagonal array constructed by placing the input arrays along the diagonal.
- Return type:
Array
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)
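Because the inputs may have fewer than two dimensions, the sketch below shows how such arguments are laid out. It assumes, matching SciPy's behavior, that a 0-D input is treated as a (1, 1) block and a 1-D input of length n as a (1, n) row. Under that assumption, the result has as many rows as all the blocks' rows combined and as many columns as all their columns combined.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

s = jnp.array(5.0)               # 0-D: assumed to become a (1, 1) block
v = jnp.array([1.0, 2.0, 3.0])   # 1-D, shape (3,): assumed to become a (1, 3) row
M = jnp.ones((2, 2))             # 2-D, shape (2, 2)

out = block_diag(s, v, M)
# Rows: 1 + 1 + 2 = 4; columns: 1 + 3 + 2 = 6.
print(out.shape)
print(out)
```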