jax.numpy.vander

jax.numpy.vander(x, N=None, increasing=False)

Generate a Vandermonde matrix.

JAX implementation of numpy.vander().
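Written out, for an input \(x = (x_0, \ldots, x_{M-1})\) with \(M\) = len(x), the default (increasing=False) output has entries \(V_{ij} = x_i^{N-1-j}\) for \(0 \le i < M\) and \(0 \le j < N\). With increasing=True the column order is reversed, so \(V_{ij} = x_i^{j}\). For the default ordering:

\[
V = \begin{bmatrix}
x_0^{N-1} & x_0^{N-2} & \cdots & x_0^{0} \\
x_1^{N-1} & x_1^{N-2} & \cdots & x_1^{0} \\
\vdots & \vdots & \ddots & \vdots \\
x_{M-1}^{N-1} & x_{M-1}^{N-2} & \cdots & x_{M-1}^{0}
\end{bmatrix}
\]

For example, with \(x_1 = 2\) and \(N = 4\), row 1 is \([2^3, 2^2, 2^1, 2^0] = [8, 4, 2, 1]\).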

Parameters:
  • x (ArrayLike) – input array. Must have x.ndim == 1.

  • N (int | None) – int, optional, default=None. Specifies the number of columns in the output matrix. If not specified, N = len(x).

  • increasing (bool) – bool, optional, default=False. Specifies the order of the powers across the columns. If True, the powers increase from left to right: \([x^0, x^1, \ldots, x^{N-1}]\). By default, the powers decrease from left to right: \([x^{N-1}, \ldots, x^1, x^0]\).
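To make the roles of N and increasing concrete, here is a minimal sketch that builds the same matrix by broadcasting a column of inputs against a row of exponents. The helper name vander_sketch is hypothetical, and this is not necessarily how jnp.vander is implemented internally.

```python
import jax.numpy as jnp

def vander_sketch(x, N=None, increasing=False):
    """Hypothetical helper: build a Vandermonde matrix via broadcasting."""
    x = jnp.asarray(x)
    if x.ndim != 1:
        raise ValueError("x must be one-dimensional")
    n_cols = x.shape[0] if N is None else N
    # Exponents 0, 1, ..., n_cols - 1, in the same dtype as x.
    powers = jnp.arange(n_cols, dtype=x.dtype)
    if not increasing:
        # Reverse so that the first column holds x**(n_cols - 1).
        powers = powers[::-1]
    # Shape (len(x), 1) ** shape (n_cols,) broadcasts to (len(x), n_cols).
    return x[:, None] ** powers
```

Since N is only used to size the exponent row, it may be smaller or larger than len(x); the result is square only when N = len(x).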

Returns:

An array of shape [len(x), N] containing the generated Vandermonde matrix.

Return type:

Array
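Because N sets the number of columns, the output shape depends on N, and under jax.jit every shape must be known at trace time. N should therefore be a concrete Python int rather than a traced value. A minimal sketch, assuming you want to compile a wrapper around jnp.vander; the name vander_jit is hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp

# N fixes the output shape and increasing selects a Python branch, so both
# are marked static; each distinct value triggers a separate compilation.
@partial(jax.jit, static_argnames=("N", "increasing"))
def vander_jit(x, N=None, increasing=False):
    return jnp.vander(x, N=N, increasing=increasing)

x = jnp.array([1.0, 2.0, 3.0])
print(vander_jit(x, N=2).shape)  # (3, 2)
```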

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.vander(x)
Array([[ 1,  1,  1,  1],
       [ 8,  4,  2,  1],
       [27,  9,  3,  1],
       [64, 16,  4,  1]], dtype=int32)
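With the default decreasing order, row i holds \(x_i^{N-1}, \ldots, x_i^0\), so multiplying by a coefficient vector ordered from highest to lowest degree evaluates a polynomial at every point of x. This is the same coefficient order that jnp.polyval uses. A short check with an illustrative polynomial \(2x^2 + 1\):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
# Coefficients of 2*x**2 + 0*x + 1, highest degree first (illustrative).
coeffs = jnp.array([2, 0, 1])

# Each row [x**2, x, 1] dotted with coeffs gives 2*x**2 + 1.
values = jnp.vander(x, N=3) @ coeffs
print(values)  # [ 3  9 19 33]
```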

With N=2, the output has only two columns, holding \(x^1\) and \(x^0\).

>>> jnp.vander(x, N=2)
Array([[1, 1],
       [2, 1],
       [3, 1],
       [4, 1]], dtype=int32)
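A two-column matrix \([x, 1]\) is the design matrix for fitting a straight line \(y \approx a x + b\). A minimal sketch using jnp.linalg.lstsq, with made-up data lying exactly on \(y = 2x + 1\) (the data are illustrative only):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = 2.0 * x + 1.0  # illustrative data lying exactly on a line

# Columns are [x**1, x**0], so the least-squares solution is (slope, intercept).
A = jnp.vander(x, N=2)
coef, residuals, rank, sing_vals = jnp.linalg.lstsq(A, y)
slope, intercept = coef
```

Raising N fits a higher-degree polynomial in the same way, although Vandermonde matrices become ill-conditioned as the degree grows.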

With increasing=True, the powers increase from left to right instead.

>>> jnp.vander(x, increasing=True)
Array([[ 1,  1,  1,  1],
       [ 1,  2,  4,  8],
       [ 1,  3,  9, 27],
       [ 1,  4, 16, 64]], dtype=int32)
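The two orderings are mirror images of each other: reversing the columns of the default matrix gives the increasing=True result. A quick check:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
default = jnp.vander(x)                       # powers x**3, x**2, x**1, x**0
increasing = jnp.vander(x, increasing=True)   # powers x**0, x**1, x**2, x**3

# Reversing the column order of one gives the other.
print(bool(jnp.array_equal(jnp.fliplr(default), increasing)))  # True
```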