jax.numpy.tri
- jax.numpy.tri(N, M=None, k=0, dtype=None)
Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of numpy.tri().
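Because jnp.tri is meant to mirror numpy.tri(), one quick consistency check is to compare both functions on the same arguments. The sketch below assumes NumPy and JAX are both installed. The shapes and offsets it tries are arbitrary, and it compares values only, because the default dtypes of the two libraries may differ.

```python
import numpy as np
import jax.numpy as jnp

# Compare jnp.tri against numpy.tri over a few shapes and diagonal offsets.
for n, m, k in [(3, None, 0), (3, 4, 0), (4, 3, 1), (3, 4, -1), (2, 5, 3)]:
    expected = np.tri(n, m, k)   # NumPy reference (float64 by default)
    actual = jnp.tri(n, m, k)    # JAX result (a floating-point dtype by default)
    assert actual.shape == expected.shape
    # Compare values only; the default dtypes may differ between libraries.
    assert np.array_equal(np.asarray(actual), expected)

print("jnp.tri matches numpy.tri on all tested cases")
```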
- Parameters:
N (int) – Number of rows in the returned array.
M (int | None) – optional. Number of columns in the returned array. If not specified, then M = N.
k (int) – optional, default=0. Index of the diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.
dtype (DTypeLike | None) – optional. Data type of the returned array. The default is a floating-point type.
- Returns:
An array of shape (N, M) with ones on and below the diagonal specified by k and zeros elsewhere.
- Return type:
Array
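Put as a formula, element (i, j) of the result is 1 when j <= i + k and 0 otherwise. The sketch below rebuilds that pattern from broadcast index grids and checks it against jnp.tri. The helper tri_reference is hypothetical and written only for illustration; it is not part of jax.numpy and is not a claim about how jnp.tri is implemented.

```python
import jax.numpy as jnp

def tri_reference(n, m=None, k=0, dtype=jnp.float32):
    """Hypothetical helper: builds the same pattern as jnp.tri from index grids."""
    m = n if m is None else m
    rows = jnp.arange(n)[:, None]  # shape (n, 1): row index i
    cols = jnp.arange(m)[None, :]  # shape (1, m): column index j
    # Element (i, j) is one exactly when j <= i + k, i.e. on or below diagonal k.
    return (cols <= rows + k).astype(dtype)

# The rule agrees with jnp.tri for negative, zero, and positive offsets.
for k in (-2, -1, 0, 1, 2):
    assert jnp.array_equal(tri_reference(3, 4, k), jnp.tri(3, 4, k))

print(tri_reference(3, 4, k=1))  # ones wherever j <= i + 1
```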
See also
jax.numpy.tril(): Returns a lower triangle of an array.
jax.numpy.triu(): Returns an upper triangle of an array.
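The See also functions are closely related to jnp.tri: masking a matrix with a tri pattern reproduces jnp.tril, and masking with the complementary pattern at offset k - 1 reproduces jnp.triu. The sketch below checks this equivalence on an arbitrary 3x4 matrix. It only illustrates the relationship; it does not describe how jnp.tril or jnp.triu are implemented internally.

```python
import jax.numpy as jnp

x = jnp.arange(1.0, 13.0).reshape(3, 4)  # arbitrary 3x4 example matrix
n, m = x.shape

for k in (-1, 0, 1):
    # Lower triangle: keep entries on and below diagonal k.
    lower_mask = jnp.tri(n, m, k, dtype=bool)
    assert jnp.array_equal(jnp.tril(x, k), jnp.where(lower_mask, x, 0))

    # Upper triangle: keep entries on and above diagonal k, i.e. zero out
    # everything on and below diagonal k - 1.
    below_mask = jnp.tri(n, m, k - 1, dtype=bool)
    assert jnp.array_equal(jnp.triu(x, k), jnp.where(below_mask, 0, x))
```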
Examples
>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)
When M is not equal to N:
>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)
When k > 0:
>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
When k < 0:
>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
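A common application of this lower-triangular pattern is a causal mask, and passing dtype=bool gives a boolean mask directly. The sketch below is illustrative only: causal_softmax is a hypothetical helper, not part of jax.numpy, and it assumes a square matrix of scores in which position i may attend only to positions j <= i.

```python
import jax
import jax.numpy as jnp

def causal_softmax(scores):
    """Hypothetical helper: softmax over the last axis with a causal mask.

    Assumes scores has shape (seq_len, seq_len). Row i keeps only columns
    j <= i, which is the pattern produced by jnp.tri with k=0.
    """
    seq_len = scores.shape[-1]
    mask = jnp.tri(seq_len, dtype=bool)         # True on and below the diagonal
    masked = jnp.where(mask, scores, -jnp.inf)  # block "future" positions
    # The diagonal is always unmasked, so no row is entirely -inf.
    return jax.nn.softmax(masked, axis=-1)

scores = jnp.zeros((4, 4))  # uniform scores, for illustration only
weights = causal_softmax(scores)
print(weights)  # row i spreads its weight evenly over columns 0..i
```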