jax.scipy.signal.detrend
- jax.scipy.signal.detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None)
Remove linear or piecewise linear trends from data.
JAX implementation of scipy.signal.detrend().
- Parameters:
data (ArrayLike) – The input array containing the data to detrend.
axis (int) – The axis along which to detrend. Default is -1 (the last axis).
type (str) –
The type of detrending. Can be:
'linear' : Fit and remove a least-squares linear trend along axis (piecewise linear if bp is given). This is the default.
'constant' : Remove the mean value of the data along axis.
bp (int) – A breakpoint index, or a sequence of breakpoint indices, along axis. If given, a separate linear trend is fit to each segment between consecutive breakpoints. Only used when type='linear'.
overwrite_data (None) – This argument is not supported by JAX’s implementation and must be left as None.
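To illustrate the bp parameter, here is a minimal sketch (the signal values are made up for illustration) comparing a single linear fit with a piecewise fit that uses one breakpoint at index 5. Each segment of this signal is exactly linear, so the piecewise residuals should be close to zero, while a single global line leaves large residuals around the kink.

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

# Hypothetical signal made of two straight segments with a kink at index 5:
# slope +1 for indices 0-4, then slope -2 from index 5 onwards.
data = jnp.array([0., 1., 2., 3., 4., 10., 8., 6., 4., 2.])

# One least-squares line over the whole signal cannot follow the kink.
single = detrend(data)

# With bp=5, separate lines are fit to data[:5] and data[5:].
piecewise = detrend(data, bp=5)

print("max |residual|, single line:", float(jnp.abs(single).max()))
print("max |residual|, bp=5:       ", float(jnp.abs(piecewise).max()))
```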
- Returns:
The detrended data array, with the same shape as data.
- Return type:
Array
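The returned array keeps the shape of data, and axis selects which 1-D slices are detrended independently. A small sketch with a hypothetical 3-by-4 array whose rows are each exactly linear but whose columns are not, so the choice of axis visibly changes the result:

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

# Hypothetical input: every row is a straight line,
# but the columns (e.g. [0, 5, 0]) are not.
x = jnp.array([[0., 1., 2., 3.],
               [5., 5., 5., 5.],
               [0., 2., 4., 6.]])

rows = detrend(x)          # axis=-1: detrend each row (length 4)
cols = detrend(x, axis=0)  # axis=0: detrend each column (length 3)

# The output has the input's shape regardless of axis.
print(rows.shape, cols.shape)

# Rows are already linear, so their residuals are ~0; the columns are not.
print(float(jnp.abs(rows).max()), float(jnp.abs(cols).max()))

# Detrending along axis=0 matches detrending the transpose along the last axis.
print(jnp.allclose(cols, detrend(x.T).T, atol=1e-5))
```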
Examples
A simple detrend operation in one dimension:
>>> import jax
>>> import jax.numpy as jnp
>>> data = jnp.array([1., 4., 8., 8., 9.])
Removing a linear trend from the data:
>>> detrended = jax.scipy.signal.detrend(data)
>>> with jnp.printoptions(precision=3, suppress=True):  # suppress float error
...     print("Detrended:", detrended)
...     print("Underlying trend:", data - detrended)
Detrended: [-1. -0.  2. -0. -1.]
Underlying trend: [ 2.  4.  6.  8. 10.]
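The 'linear' mode removes an ordinary least-squares line. As a cross-check (a sketch that assumes jnp.polyfit as the reference fit; it is not how detrend is documented to work internally), the underlying trend shown above can be reproduced by fitting a degree-1 polynomial to the same data:

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

data = jnp.array([1., 4., 8., 8., 9.])
t = jnp.arange(data.shape[0], dtype=data.dtype)

# Least-squares line through (t, data); polyfit returns [slope, intercept].
slope, intercept = jnp.polyfit(t, data, deg=1)
trend = slope * t + intercept  # approximately [2, 4, 6, 8, 10]

# detrend(type='linear') should subtract this same least-squares line.
print(jnp.allclose(detrend(data), data - trend, atol=1e-4))
```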
Removing a constant trend from the data:
>>> detrended = jax.scipy.signal.detrend(data, type='constant')
>>> with jnp.printoptions(precision=3):  # suppress float error
...     print("Detrended:", detrended)
...     print("Underlying trend:", data - detrended)
Detrended: [-5. -2.  2.  2.  3.]
Underlying trend: [6. 6. 6. 6. 6.]
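type='constant' amounts to subtracting the mean along axis. A minimal check of that equivalence on the same data:

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

data = jnp.array([1., 4., 8., 8., 9.])

# Subtract the mean along the last axis, keeping dims so it broadcasts back.
manual = data - data.mean(axis=-1, keepdims=True)

# type='constant' should give the same result as the manual mean removal.
print(jnp.allclose(detrend(data, type='constant'), manual))
```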