jax.scipy.interpolate.RegularGridInterpolator
- class jax.scipy.interpolate.RegularGridInterpolator(points, values, method='linear', bounds_error=False, fill_value=nan)
Interpolate points on a regular rectangular grid.
JAX implementation of scipy.interpolate.RegularGridInterpolator().
- Parameters:
points – length-N sequence of arrays specifying the grid coordinates.
values – N-dimensional array specifying the grid values.
method – interpolation method, either "linear" or "nearest".
bounds_error – not implemented by JAX.
fill_value – value returned for points outside the grid, defaults to NaN.
- Returns:
callable interpolation object.
- Return type:
interpolator
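The method and fill_value parameters described above can be exercised with a short sketch. The grid, the query points, and the variable names (nearest_vals, outside_default, outside_filled) are illustrative assumptions, not part of the original documentation.

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Illustrative 3x3 grid: values[i, j] sits at (points[0][i], points[1][j]).
points = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])

# method="nearest" returns the value stored at the closest grid node:
# (1.4, 4.4) snaps to node (1, 4) and (2.6, 5.7) snaps to node (3, 6).
nearest = RegularGridInterpolator(points, values, method="nearest")
nearest_vals = nearest(jnp.array([[1.4, 4.4], [2.6, 5.7]]))

# With the default fill_value, a query outside the grid yields NaN.
linear = RegularGridInterpolator(points, values)
outside_default = linear(jnp.array([[0.0, 4.5]]))

# A custom fill_value replaces NaN for out-of-grid queries only;
# the second query lies inside the grid and is interpolated normally.
filled = RegularGridInterpolator(points, values, fill_value=-1.0)
outside_filled = filled(jnp.array([[0.0, 4.5], [1.5, 4.5]]))
```

A sentinel such as -1.0 can be easier to filter downstream than NaN, at the cost of being indistinguishable from a genuine grid value of -1.0.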
Examples
>>> import jax.numpy as jnp
>>> from jax.scipy.interpolate import RegularGridInterpolator
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
>>> interpolate = RegularGridInterpolator(points, values, method='linear')

>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
>>> interpolate(query_points)
Array([30., 64.], dtype=float32)
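The two results in the example can be checked by hand with the standard bilinear formula. The sketch below is only an illustration of that arithmetic for the 2-D case, not the library's internal implementation; the helper name bilinear_by_hand is hypothetical.

```python
import jax.numpy as jnp

x_grid = jnp.array([1.0, 2.0, 3.0])
y_grid = jnp.array([4.0, 5.0, 6.0])
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])

def bilinear_by_hand(x, y):
    """Hypothetical helper: bilinear interpolation at one in-grid point."""
    # Index of the lower corner of the cell containing (x, y),
    # clipped so that i + 1 and j + 1 remain valid indices.
    i = int(jnp.clip(jnp.searchsorted(x_grid, x, side="right") - 1,
                     0, x_grid.size - 2))
    j = int(jnp.clip(jnp.searchsorted(y_grid, y, side="right") - 1,
                     0, y_grid.size - 2))
    # Fractional position inside the cell along each axis, in [0, 1].
    tx = (x - x_grid[i]) / (x_grid[i + 1] - x_grid[i])
    ty = (y - y_grid[j]) / (y_grid[j + 1] - y_grid[j])
    # Weighted average of the four surrounding grid values.
    return ((1 - tx) * (1 - ty) * values[i, j]
            + (1 - tx) * ty * values[i, j + 1]
            + tx * (1 - ty) * values[i + 1, j]
            + tx * ty * values[i + 1, j + 1])

first = bilinear_by_hand(1.5, 4.5)   # cell centre: mean of 10, 20, 40, 50
second = bilinear_by_hand(2.2, 5.8)  # weights 0.2 along x and 0.8 along y
```

For (2.2, 5.8) the arithmetic is 0.8·0.2·50 + 0.8·0.8·60 + 0.2·0.2·80 + 0.2·0.8·90 = 64, matching the second entry of the output above; the first query sits at a cell centre and gives 30.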
Methods
__init__(points, values[, method, ...])