jax.numpy.gradient

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)

Compute the numerical gradient of a sampled function.

JAX implementation of numpy.gradient().

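The gradient in jnp.gradient is computed using finite differences across the array of sampled function values: second-order accurate central differences at interior points, and first-order accurate one-sided (forward or backward) differences at the two boundary points along each axis. This should not be confused with jax.grad(), which computes the exact gradient of a callable function (up to floating-point rounding) via automatic differentiation.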
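As an illustration of the scheme described above, the following sketch writes out the unit-spacing case by hand and compares it with jnp.gradient. The helper gradient_unit_spacing is hypothetical, written only for this example, and is not part of JAX; it assumes a 1-D array with at least two samples.

```python
import jax.numpy as jnp

def gradient_unit_spacing(y):
    """Hypothetical helper: 1-D finite-difference gradient with unit spacing."""
    first = y[1:2] - y[0:1]            # forward difference at the left edge (first order)
    interior = (y[2:] - y[:-2]) / 2    # central differences at interior points (second order)
    last = y[-1:] - y[-2:-1]           # backward difference at the right edge (first order)
    return jnp.concatenate([first, interior, last])

y = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])
manual = gradient_unit_spacing(y)
builtin = jnp.gradient(y)
print(jnp.allclose(manual, builtin))  # True
```

For these samples both give 1, 1.5, 2.5, 3.5 and 4: the endpoints use y[1] - y[0] and y[4] - y[3], while each interior value is half the difference between its two neighbors.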

Parameters:
  • f (ArrayLike) – N-dimensional array of function values.

  • varargs (ArrayLike) –

    optional scalars or arrays, passed as additional positional arguments, specifying the spacing between function evaluations. Options are:

    • not specified: unit spacing in all dimensions.

    • a single scalar: constant spacing in all dimensions.

    • N values: specify different spacing in each dimension:

      • scalar values indicate constant spacing in that dimension.

      • array values must match the length of the corresponding dimension, and specify the coordinates at which f is evaluated.

  • edge_order (int | None) – not implemented in JAX; passing any value other than None raises NotImplementedError.

  • axis (int | Sequence[int] | None) – integer or tuple of integers specifying the axis or axes along which to compute the gradient. If None (default), the gradient is calculated along all axes.

Returns:

a single array if the gradient is computed along only one axis; otherwise a list of arrays, one per axis. Each array has the same shape as f.

Return type:

Array | list[Array]
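For a 2-D input, the following sketch shows how axis and the spacing arguments determine the return type: with the default axis=None and one spacing per dimension, one array per axis is returned; with a single integer axis, a single array is returned and only that axis's spacing is passed. The test function z = x * y**2 and the grid spacings are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# Samples of z(x, y) = x * y**2 on a grid with spacing 0.5 along x and 2.0 along y.
x = jnp.arange(4) * 0.5            # shape (4,): axis 0 of z
y = jnp.arange(3) * 2.0            # shape (3,): axis 1 of z
z = x[:, None] * y[None, :] ** 2   # shape (4, 3)

# axis=None (default): one spacing per dimension, one gradient array per axis.
grads = jnp.gradient(z, 0.5, 2.0)
dz_dx, dz_dy = grads
print(len(grads), dz_dx.shape, dz_dy.shape)  # 2 (4, 3) (4, 3)

# A single integer axis: one spacing, and a single array is returned.
dz_dx_only = jnp.gradient(z, 0.5, axis=0)
print(dz_dx_only.shape)  # (4, 3)

# z is linear in x, so dz/dx = y**2 is recovered exactly at every point.
print(jnp.allclose(dz_dx_only, jnp.broadcast_to(y ** 2, z.shape)))  # True
```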

See also

  • jax.grad(): automatic differentiation of a function with a single output.

Examples

Comparing numerical and automatic differentiation of a simple function:

>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]

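Notice that, as expected, the numerical gradient has some approximation error compared to the automatic gradient computed via jax.grad(), which matches the exact gradient to the displayed precision. The error is largest at the two endpoints, where jnp.gradient uses first-order one-sided differences instead of central differences.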