jax.scipy.signal.fftconvolve
- jax.scipy.signal.fftconvolve(in1, in2, mode='full', axes=None)
Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).
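FFT-based convolution relies on the convolution theorem: if both inputs are zero-padded to the full output length n + m - 1, the pointwise product of their Fourier transforms is the transform of their linear convolution. The sketch below illustrates that idea in 1D. It is not the library's internal implementation; the helper `fft_convolve_1d` is hypothetical and only uses the public `jnp.fft.rfft`/`jnp.fft.irfft` functions, and its result is compared against `jnp.convolve` and `fftconvolve`.

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def fft_convolve_1d(a, b):
    """Hypothetical helper: full 1D linear convolution via the convolution theorem."""
    n_full = a.shape[0] + b.shape[0] - 1
    # Zero-pad both inputs to the full output length so that the circular
    # convolution computed by the FFT coincides with the linear convolution.
    a_hat = jnp.fft.rfft(a, n=n_full)
    b_hat = jnp.fft.rfft(b, n=n_full)
    # A pointwise product in the frequency domain is a convolution in the signal domain.
    return jnp.fft.irfft(a_hat * b_hat, n=n_full)


x = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])
y = jnp.array([1.0, 1.0, 1.0])

sketch = fft_convolve_1d(x, y)
reference = jnp.convolve(x, y)   # direct 1D convolution
library = fftconvolve(x, y)      # JAX's FFT-based routine

print(jnp.allclose(sketch, reference, atol=1e-4))  # True
print(jnp.allclose(sketch, library, atol=1e-4))    # True
```

The tolerance in the comparisons reflects float32 round-off from the forward and inverse transforms.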
JAX implementation of scipy.signal.fftconvolve().
- Parameters:
in1 (ArrayLike) – left-hand input to the convolution.
in2 (ArrayLike) – right-hand input to the convolution. Must have in1.ndim == in2.ndim.
mode (str) – controls the size of the output. Available options are:
  "full": (default) output the full convolution of the inputs.
  "same": return a centered portion of the "full" output that is the same size as in1.
  "valid": return the portion of the "full" output that does not depend on padding at the array edges.
axes (Sequence[int] | None) – optional sequence of axes along which to apply the convolution. If None (the default), the convolution is applied over all axes.
- Returns:
Array containing the convolved result.
- Return type:
Array
See also
jax.numpy.convolve(): 1D convolution
jax.scipy.signal.convolve(): direct convolution
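Where jax.numpy.convolve() handles a single pair of 1D arrays, the axes parameter of fftconvolve restricts the convolution to the listed axes, so a stack of signals can be convolved row by row in one call. The sketch below assumes that both inputs have the same size along the non-convolved axis (here, two rows each) and checks the result against `jax.vmap` applied to `jnp.convolve`. The arrays are illustrative values, not taken from the library documentation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

# Two signals and two kernels, stacked along axis 0; both inputs have the
# same number of rows along the non-convolved axis (assumed requirement).
x = jnp.array([[1.0, 2.0, 3.0, 2.0, 1.0],
               [0.0, 1.0, 0.0, -1.0, 0.0]])
y = jnp.array([[1.0, 1.0, 1.0],
               [1.0, -1.0, 0.0]])

# Convolve only along axis 1: row i of x is convolved with row i of y.
batched = fftconvolve(x, y, mode='full', axes=(1,))

# Reference: map the 1D direct convolution over the rows.
rowwise = jax.vmap(jnp.convolve)(x, y)

print(batched.shape)                              # (2, 7)
print(jnp.allclose(batched, rowwise, atol=1e-4))  # True
```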
Examples
A few 1D convolution examples. Because FFT-based convolution is subject to floating-point round-off error, we use jax.numpy.printoptions() below to limit the printing precision:
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 1, 1])
Full convolution uses implicit zero-padding at the edges:
>>> with jax.numpy.printoptions(precision=3):
...     print(jax.scipy.signal.fftconvolve(x, y, mode='full'))
[1. 3. 6. 7. 6. 3. 1.]
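To make the implicit zero-padding of the "full" mode concrete, the sketch below pads x with m - 1 zeros on each side and computes each output as a dot product of a sliding window with the reversed kernel. This is a plain direct-convolution illustration chosen for readability, not how fftconvolve computes its result; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

x = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])
y = jnp.array([1.0, 1.0, 1.0])
n, m = x.shape[0], y.shape[0]

# Make the implicit zero-padding explicit: m - 1 zeros on each side of x.
xp = jnp.pad(x, (m - 1, m - 1))

# Each "full" output is the dot product of a length-m window of the padded
# input with the reversed kernel (convolution flips the kernel).
y_rev = y[::-1]
explicit = jnp.array([jnp.dot(xp[k:k + m], y_rev) for k in range(n + m - 1)])

print(explicit)  # [1. 3. 6. 7. 6. 3. 1.]
print(jnp.allclose(explicit, fftconvolve(x, y, mode='full'), atol=1e-4))  # True
```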
Specifying mode='same' returns a centered convolution the same size as the first input:
>>> with jax.numpy.printoptions(precision=3):
...     print(jax.scipy.signal.fftconvolve(x, y, mode='same'))
[3. 6. 7. 6. 3.]
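The "same" output can be read as a slice of the "full" output. The sketch below assumes the centering convention used by SciPy, where the slice of length n starts at offset (m - 1) // 2; it is checked only for the odd-length kernel of this example, so behavior for even-length kernels is not asserted here.

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

x = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])
y = jnp.array([1.0, 1.0, 1.0])
n, m = x.shape[0], y.shape[0]

full = fftconvolve(x, y, mode='full')    # length n + m - 1 = 7
start = (m - 1) // 2                     # assumed centering offset
same_by_slicing = full[start:start + n]  # length n = 5

print(jnp.allclose(same_by_slicing, fftconvolve(x, y, mode='same'), atol=1e-4))  # True
```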
Specifying mode='valid' returns only the portion where the two arrays fully overlap:
>>> with jax.numpy.printoptions(precision=3):
...     print(jax.scipy.signal.fftconvolve(x, y, mode='valid'))
[6. 7. 6.]
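Likewise, the "valid" output can be sketched as the "full" output with the first and last m - 1 entries removed, since those are the positions where the kernel overlaps the zero padding. This sketch assumes the first input is at least as long as the second (n >= m), as in the example above.

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

x = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])
y = jnp.array([1.0, 1.0, 1.0])
n, m = x.shape[0], y.shape[0]   # this sketch assumes n >= m

full = fftconvolve(x, y, mode='full')
# Drop the m - 1 entries at each end that touch the zero padding,
# leaving n - m + 1 fully overlapping positions.
valid_by_slicing = full[m - 1:n]

print(jnp.allclose(valid_by_slicing, fftconvolve(x, y, mode='valid'), atol=1e-4))  # True
```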