jax.scipy.signal.detrend

jax.scipy.signal.detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None)

Remove linear or piecewise linear trends from data.

JAX implementation of scipy.signal.detrend().

Parameters:
  • data (ArrayLike) – The input array containing the data to detrend.

  • axis (int) – The axis along which to detrend. Default is -1 (the last axis).

  • type (str) –

    The type of detrending. Default is 'linear'. Can be:

    • 'linear': Subtract a least-squares linear fit from the data. The fit covers the entire data unless bp is given.

    • 'constant': Subtract the mean value of the data.

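  • bp (int) – A breakpoint index, or a sequence of breakpoint indices, into data along axis. If given, a separate linear trend is fit to each segment between consecutive breakpoints. Only used when type='linear'. Default is 0 (no breakpoints).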

  • overwrite_data (None) – This argument is not supported by JAX’s implementation and should be left as None.

Returns:

The detrended data, an array with the same shape as data.

Return type:

Array

Examples

A simple detrend operation in one dimension:

>>> import jax.numpy as jnp
>>> import jax.scipy.signal
>>> data = jnp.array([1., 4., 8., 8., 9.])

Removing a linear trend from the data:

>>> detrended = jax.scipy.signal.detrend(data)
>>> with jnp.printoptions(precision=3, suppress=True):  # suppress float error
...   print("Detrended:", detrended)
...   print("Underlying trend:", data - detrended)
Detrended: [-1. -0.  2. -0. -1.]
Underlying trend: [ 2.  4.  6.  8. 10.]
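
The underlying trend above is the ordinary least-squares line through the samples. As a rough cross-check (a sketch, not a description of how detrend is implemented internally), the same line can be recovered with jnp.polyfit, assuming the samples sit at positions 0, 1, ..., N-1. The names x, slope, intercept and trend are illustrative only:

```python
import jax.numpy as jnp
import jax.scipy.signal

data = jnp.array([1., 4., 8., 8., 9.])

# Assumed sample positions: 0, 1, ..., N-1.
x = jnp.arange(data.shape[0], dtype=data.dtype)

# Degree-1 least-squares fit; coefficients come back highest degree first.
slope, intercept = jnp.polyfit(x, data, 1)
trend = slope * x + intercept

detrended = jax.scipy.signal.detrend(data)

# The removed trend and the fitted line should agree up to float error.
print(jnp.allclose(data - detrended, trend, atol=1e-4))
```

The comparison uses a tolerance because JAX defaults to 32-bit floats, so the two results are not expected to match bit for bit.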

Removing a constant trend from the data:

>>> detrended = jax.scipy.signal.detrend(data, type='constant')
>>> with jnp.printoptions(precision=3):  # suppress float error
...   print("Detrended:", detrended)
...   print("Underlying trend:", data - detrended)
Detrended: [-5. -2.  2.  2.  3.]
Underlying trend: [6. 6. 6. 6. 6.]
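
Removing a piecewise linear trend with bp. The following is a minimal sketch on made-up data, assuming bp accepts a Python int index as in scipy.signal.detrend(); the signal and the variable names piecewise and single are illustrative only:

```python
import jax.numpy as jnp
import jax.scipy.signal

# Made-up signal: rises with slope 1 for four samples, then falls with slope -1.
data = jnp.array([0., 1., 2., 3., 3., 2., 1., 0.])

# A breakpoint at index 4 splits the data into data[:4] and data[4:];
# a separate linear trend is fit to and removed from each segment.
piecewise = jax.scipy.signal.detrend(data, bp=4)

# Without a breakpoint, a single line is fit to all eight samples.
single = jax.scipy.signal.detrend(data)

with jnp.printoptions(precision=3, suppress=True):
    print("Piecewise residual:", piecewise)
    print("Single-fit residual:", single)
```

Because each half of this signal is exactly linear, the piecewise residual should be zero up to float error, while the single fit can only remove the overall mean and leaves the triangular shape in place.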