jax.scipy.interpolate.RegularGridInterpolator

class jax.scipy.interpolate.RegularGridInterpolator(points, values, method='linear', bounds_error=False, fill_value=nan)

Interpolate points on a regular rectangular grid.

JAX implementation of scipy.interpolate.RegularGridInterpolator().
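In the two-dimensional "linear" case, each query point is a weighted average of the four grid values at the corners of the grid cell that contains it. The following is a minimal hand-written sketch of that computation for in-bounds points. It is an illustration only, not JAX's actual implementation, which handles any number of dimensions, the "nearest" method and fill_value. The helper name bilinear is hypothetical, and the sketch assumes float grids sorted in strictly increasing order.

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator


def bilinear(x_grid, y_grid, values, query):
    """Hypothetical helper: 2-D linear interpolation for in-bounds points (no fill_value handling)."""
    x, y = query[:, 0], query[:, 1]
    # Index of the lower cell edge in each dimension, clipped so that i + 1 stays a valid index.
    i = jnp.clip(jnp.searchsorted(x_grid, x) - 1, 0, x_grid.size - 2)
    j = jnp.clip(jnp.searchsorted(y_grid, y) - 1, 0, y_grid.size - 2)
    # Fractional position of each point inside its cell (between 0 and 1 for in-bounds points).
    t = (x - x_grid[i]) / (x_grid[i + 1] - x_grid[i])
    s = (y - y_grid[j]) / (y_grid[j + 1] - y_grid[j])
    # Weighted sum of the four corner values of the cell.
    return ((1 - t) * (1 - s) * values[i, j]
            + t * (1 - s) * values[i + 1, j]
            + (1 - t) * s * values[i, j + 1]
            + t * s * values[i + 1, j + 1])


x_grid = jnp.array([1.0, 2.0, 3.0])
y_grid = jnp.array([4.0, 5.0, 6.0])
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])
query = jnp.array([[1.5, 4.5], [2.2, 5.8]])

manual = bilinear(x_grid, y_grid, values, query)
# Compare against the library class on the same in-bounds points.
reference = RegularGridInterpolator((x_grid, y_grid), values, method='linear')(query)
```

For the first query point the weights are all 0.25, so the result is the plain average (10 + 40 + 20 + 50) / 4 = 30.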

Parameters:
  • points – length-N sequence of arrays specifying the grid coordinates.

  • values – N-dimensional array specifying the grid values.

  • method – interpolation method, either "linear" or "nearest".

  • bounds_error – not implemented by JAX; leave it at the default, False.

  • fill_value – value returned for points outside the grid, defaults to NaN.
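As a sketch of the method and fill_value parameters, the example below reuses the grid from the Examples section (with float values) and adds one query point outside the grid. The query coordinates are chosen so that no point lies exactly halfway between two grid nodes, which avoids relying on how ties are broken by "nearest".

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

points = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])

# The last query point lies outside the grid because x = 0.5 < 1.
query = jnp.array([[1.4, 4.6], [2.2, 5.8], [0.5, 4.5]])

# "nearest" returns the value at the closest grid node: (1, 5) -> 20 and (2, 6) -> 60.
# The out-of-bounds point gets the default fill_value, NaN.
nearest_out = RegularGridInterpolator(points, values, method='nearest')(query)

# Same query with an explicit fill_value for points outside the grid.
filled_out = RegularGridInterpolator(points, values, method='nearest',
                                     fill_value=-1.0)(query)
```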

Returns:

callable interpolation object.

Return type:

interpolator
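The returned object works the same way for any number of dimensions N. In this sketch the grid is three-dimensional (N = 3) with axes of different lengths and spacings, and the callable is applied to an array of shape (M, 3), one row per query point. Because the sampled function f(x, y, z) = x + 2y + 3z is linear, "linear" interpolation reproduces it exactly up to float32 rounding, which makes the output easy to check by hand.

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Three grid axes; they need not share a length or a spacing.
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.0, 1.0, 2.0, 3.0])
z = jnp.array([0.0, 0.5, 1.0])

# indexing='ij' gives values[i, j, k] = f(x[i], y[j], z[k]), matching the order of `points`.
X, Y, Z = jnp.meshgrid(x, y, z, indexing='ij')
values = X + 2 * Y + 3 * Z            # shape (3, 4, 3)

interp = RegularGridInterpolator((x, y, z), values)

# Query array of shape (M, N) = (2, 3): one row per point, one column per dimension.
query = jnp.array([[0.5, 1.5, 0.25],
                   [1.9, 2.2, 0.8]])
result = interp(query)                # shape (2,); approximately [4.25, 8.7]
```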

Examples

>>> import jax.numpy as jnp
>>> from jax.scipy.interpolate import RegularGridInterpolator
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
>>> interpolate = RegularGridInterpolator(points, values, method='linear')
>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
>>> interpolate(query_points)
Array([30., 64.], dtype=float32)
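Since the interpolator is built from JAX operations, it can be used inside jax.jit and differentiated with jax.grad with respect to the query coordinates. The following is a minimal sketch under the assumption that the interpolator is simply closed over by the jitted function; the function names evaluate and scalar_value are hypothetical. Inside the cell [1, 2] × [4, 5], the gradient at (1.5, 4.5) is the local bilinear slope: 0.5·(40 − 10) + 0.5·(50 − 20) = 30 along x and 0.5·(20 − 10) + 0.5·(50 − 40) = 10 along y.

```python
import jax
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

points = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])
interpolate = RegularGridInterpolator(points, values, method='linear')


@jax.jit
def evaluate(query):
    # The interpolator is captured from the enclosing scope; its grid and values become constants.
    return interpolate(query)


def scalar_value(point):
    # Interpolate a single (x, y) point and return a scalar so that jax.grad applies.
    return interpolate(point[None, :])[0]


jitted = evaluate(jnp.array([[1.5, 4.5], [2.2, 5.8]]))    # approximately [30., 64.]
grad = jax.grad(scalar_value)(jnp.array([1.5, 4.5]))      # approximately [30., 10.]
```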

__init__(points, values, method='linear', bounds_error=False, fill_value=nan)

Methods

__init__(points, values[, method, ...])