jax.numpy.poly
jax.numpy.poly(seq_of_zeros)
Returns the coefficients of a polynomial for the given sequence of roots.
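JAX implementation of numpy.poly().

- Parameters:
seq_of_zeros (ArrayLike) – A scalar, a 1-D array of roots with shape (M,), or a square matrix with shape (M, M). For a square matrix, the coefficients of its characteristic polynomial are returned.
- Returns:
An array containing the coefficients of the polynomial, ordered from the highest degree to the constant term. The dtype of the output is always promoted to inexact.
- Return type:
Array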
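As a minimal sketch of what "coefficients for a sequence of roots" means: a monic polynomial with roots r_1, ..., r_M is the product (x - r_1)(x - r_2)...(x - r_M), so its coefficients can be built by repeatedly convolving with [1, -r]. The helper `poly_from_roots` is hypothetical, handles only 1-D roots, and illustrates the idea rather than describing how `jnp.poly` is implemented internally.

```python
import jax.numpy as jnp


def poly_from_roots(roots):
    """Hypothetical helper: coefficients of prod_k (x - roots[k]), highest degree first."""
    roots = jnp.atleast_1d(jnp.asarray(roots))
    # Promote to an inexact dtype, mirroring jnp.poly's documented behavior.
    dtype = jnp.result_type(roots.dtype, jnp.float32)
    coeffs = jnp.ones(1, dtype=dtype)  # start from the constant polynomial 1
    for r in roots:
        # Multiplying by (x - r) is a full convolution with the coefficients [1, -r].
        factor = jnp.array([1, -r], dtype=dtype)
        coeffs = jnp.convolve(coeffs, factor)
    return coeffs


roots = jnp.array([1, 2, 3])
print(poly_from_roots(roots))
print(jnp.poly(roots))
```

Both calls should print the same coefficients; integer roots come back as a floating-point array in both cases.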
Note
jax.numpy.poly() differs from numpy.poly():
- When the input is a scalar, np.poly raises a TypeError, whereas jnp.poly treats scalars the same as length-1 arrays.
- For complex-valued or square-shaped inputs, jnp.poly always returns complex coefficients, whereas np.poly may return real or complex coefficients depending on their values.
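The dtype difference described in the note can be seen by running both libraries on the complex-conjugate roots used in the examples. This sketch assumes NumPy is installed alongside JAX and relies on NumPy returning real coefficients when all complex roots come in conjugate pairs; the exact dtypes printed may depend on configuration (for example, whether JAX's 64-bit mode is enabled).

```python
import numpy as np
import jax.numpy as jnp

roots = [2, 1 + 2j, 1 - 2j]

np_coeffs = np.poly(roots)               # NumPy: conjugate pairs yield real coefficients
jnp_coeffs = jnp.poly(jnp.array(roots))  # JAX: complex input always yields complex output

print(np_coeffs.dtype, np_coeffs)
print(jnp_coeffs.dtype, jnp_coeffs)

# The values agree; only the dtype differs.
print(np.allclose(np_coeffs, np.asarray(jnp_coeffs)))

# A scalar root is treated like a length-1 array by jnp.poly.
print(jnp.poly(1.0))
```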
See also
jax.numpy.polyfit(): Least squares polynomial fit.
jax.numpy.polyval(): Evaluate a polynomial at specific values.
jax.numpy.roots(): Compute the roots of a polynomial for given coefficients.
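The related functions listed above can be combined in a quick round trip. This sketch builds coefficients with jnp.poly, checks with jnp.polyval that the polynomial vanishes at the chosen roots, and recovers the roots with jnp.roots. jnp.roots generally returns complex values and, with its default settings, expects concrete (non-traced) input, so it is called eagerly here rather than under jax.jit. The tolerances in any comparison are illustrative choices for float32.

```python
import jax.numpy as jnp

roots = jnp.array([1.0, 2.0, 3.0])
coeffs = jnp.poly(roots)  # coefficients of x**3 - 6*x**2 + 11*x - 6

# Evaluating the polynomial at its own roots should give values close to zero.
residuals = jnp.polyval(coeffs, roots)
print(residuals)

# jnp.roots inverts jnp.poly up to ordering and floating-point error;
# take the real part and sort, since these roots are known to be real.
recovered = jnp.sort(jnp.roots(coeffs).real)
print(recovered)
```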
Examples
Scalar input:
>>> jnp.poly(1)
Array([ 1., -1.], dtype=float32)
Input array with integer values:
>>> x = jnp.array([1, 2, 3])
>>> jnp.poly(x)
Array([ 1., -6., 11., -6.], dtype=float32)
Input array with complex conjugates:
>>> x = jnp.array([2, 1+2j, 1-2j])
>>> jnp.poly(x)
Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
Input array as a square matrix with real values:
>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.round(jnp.poly(x))
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
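For a square matrix, the result is the characteristic polynomial of the matrix, so it should agree with jnp.poly applied to the matrix's eigenvalues. This sketch checks that for the matrix from the last example and compares two coefficients against identities that hold for any 3x3 matrix A: the second coefficient equals -trace(A) and the constant term equals -det(A). It assumes jnp.linalg.eigvals is available on the current backend (it may be CPU-only on some JAX versions), and the tolerances are illustrative choices for float32.

```python
import jax.numpy as jnp

x = jnp.array([[2.0, 1.0, 5.0],
               [3.0, 4.0, 7.0],
               [1.0, 3.0, 5.0]])

from_matrix = jnp.poly(x)                    # characteristic polynomial of x
from_eigs = jnp.poly(jnp.linalg.eigvals(x))  # polynomial built from the eigenvalues
print(jnp.allclose(from_matrix, from_eigs, atol=1e-4))

# For a 3x3 matrix: det(t*I - A) = t**3 - tr(A)*t**2 + c1*t - det(A),
# where c1 is the sum of the principal 2x2 minors.
print(from_matrix[1].real, -jnp.trace(x))
print(from_matrix[3].real, -jnp.linalg.det(x))
```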