jax.numpy.tri

jax.numpy.tri(N, M=None, k=0, dtype=None)

Return an array with ones on and below the diagonal and zeros elsewhere.

JAX implementation of numpy.tri().
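Because jnp.tri mirrors numpy.tri, both should produce the same values for the same arguments. The sketch below, which assumes NumPy is installed, checks this for a few argument combinations. Only values are compared, since the default dtypes may differ (NumPy's numpy.tri defaults to float64).

```python
import numpy as np
import jax.numpy as jnp

# A few (positional args, keyword args) combinations taken from the examples on this page.
cases = [((3,), {}), ((3, 4), {}), ((3,), {"k": 1}), ((3, 4), {"k": -1})]

for args, kwargs in cases:
    expected = np.tri(*args, **kwargs)             # NumPy reference result
    actual = np.asarray(jnp.tri(*args, **kwargs))  # JAX result converted to a NumPy array
    # array_equal compares values and shapes, ignoring the float64/float32 difference.
    assert np.array_equal(expected, actual), (args, kwargs)
```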

Parameters:
  • N (int) – Number of rows of the returned array.

  • M (int | None) – optional. Number of columns of the returned array. If not specified, M = N.

  • k (int) – optional, default=0. Index of the diagonal on and below which the array is filled with ones: element (i, j) is one when j <= i + k and zero otherwise. k=0 refers to the main diagonal, k<0 to a diagonal below the main diagonal, and k>0 to a diagonal above the main diagonal.

  • dtype (DTypeLike | None) – optional. Data type of the returned array. If not specified, a floating-point dtype is used (float32 in the examples below).
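The rule described for k can be sketched directly with broadcasting. The helper tri_reference below is hypothetical (it is not part of JAX) and only illustrates which entries are set to one by comparing row and column indices; it is not a description of JAX's internal implementation.

```python
import jax.numpy as jnp

def tri_reference(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical helper: builds the same pattern as jnp.tri via broadcasting.

    Element (i, j) is one when j <= i + k, i.e. on or below the k-th diagonal.
    """
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]  # shape (N, 1): row index i
    cols = jnp.arange(M)[None, :]  # shape (1, M): column index j
    # Broadcasting the comparison yields an (N, M) boolean array, then cast to dtype.
    return (cols <= rows + k).astype(dtype)

print(tri_reference(3, 4, k=-1))
```

For the same arguments, tri_reference(N, M, k) should match jnp.tri(N, M, k), including for negative k.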

Returns:

An array of shape (N, M) in which elements on and below the diagonal specified by k are set to one and all other elements are zero.
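The shape and element type of the result can be checked directly. In this sketch the sizes and dtypes are arbitrary choices; passing dtype=bool yields a boolean mask instead of ones and zeros.

```python
import jax.numpy as jnp

# The result always has shape (N, M); `dtype` controls the element type.
mask = jnp.tri(2, 3, dtype=bool)              # True on and below the main diagonal
counts = jnp.tri(2, 3, k=1, dtype=jnp.int32)  # integer ones and zeros, shifted up by one diagonal

print(mask.shape, mask.dtype)
print(counts)
```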

Return type:

Array


Examples

>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

When M is not equal to N:

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

When k>0:

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

When k<0:

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
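A common use of the result is as a mask. As a minimal sketch (the input a is an arbitrary example, not from this page), multiplying an (N, M) array elementwise by jnp.tri(N, M, k) zeroes every entry above the k-th diagonal, which should agree with jax.numpy.tril:

```python
import jax.numpy as jnp

# An arbitrary (3, 4) input array to be masked.
a = jnp.arange(12.0).reshape(3, 4)

# Keep entries on and below the k=-1 diagonal; everything above it becomes zero.
masked = a * jnp.tri(3, 4, k=-1)

print(masked)
print(jnp.array_equal(masked, jnp.tril(a, k=-1)))
```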