jax.test_util.check_vjp

jax.test_util.check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=0.0001, err_msg='')

Check a VJP from automatic differentiation against finite differences.

Gradients are only checked in a single randomly chosen direction, which ensures that the finite difference calculation does not become prohibitively expensive even for large input/output spaces.
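The single-direction idea can be sketched as follows. This is a minimal illustration, not JAX's actual implementation. `sketch_check_vjp` is a hypothetical helper that assumes `f` takes one array and returns one array, and it uses a central finite difference. For a random input direction u and a random output cotangent v, a correct VJP satisfies <v, J u> = <J^T v, u>. One finite-difference JVP along u therefore suffices, whatever the input and output sizes.

```python
import jax
import jax.numpy as jnp
import numpy as np

# float64 keeps finite-difference round-off far below the tolerances used here.
jax.config.update("jax_enable_x64", True)


def sketch_check_vjp(f, x, eps=1e-4, rtol=1e-5, atol=1e-8, seed=0):
    """Hypothetical helper: check jax.vjp of a single-array function f at x
    along one random input direction and one random output cotangent."""
    rng = np.random.default_rng(seed)
    y, f_vjp = jax.vjp(f, x)
    u = jnp.asarray(rng.standard_normal(np.shape(x)))  # random input direction
    v = jnp.asarray(rng.standard_normal(np.shape(y)))  # random output cotangent
    # Central finite difference approximates the JVP  J @ u.
    jvp_fd = (f(x + eps * u) - f(x - eps * u)) / (2 * eps)
    # Reverse-mode autodiff gives the VJP  J^T @ v.
    (vjp_ad,) = f_vjp(v)
    # A correct VJP makes both inner products agree: <v, J u> == <J^T v, u>.
    lhs = jnp.vdot(v, jvp_fd)
    rhs = jnp.vdot(vjp_ad, u)
    np.testing.assert_allclose(lhs, rhs, rtol=rtol, atol=atol)
    return float(lhs), float(rhs)


# Usage: a smooth elementwise function on a small vector.
lhs, rhs = sketch_check_vjp(lambda x: jnp.sin(x) * x ** 2, jnp.linspace(0.1, 1.0, 5))
```

Only two extra evaluations of `f` are needed for the finite difference, which is why checking a single random direction stays cheap even when the full Jacobian would be large.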

Parameters:
  • f – function to check at f(*args).

  • f_vjp – function that, when called as f_vjp(*args), returns the same (primal output, VJP function) pair as jax.vjp(f, *args). Typically this should be functools.partial(jax.vjp, f).

  • args – tuple of argument values.

  • atol – absolute tolerance for gradient equality. If None, a dtype-dependent default is used.

  • rtol – relative tolerance for gradient equality. If None, a dtype-dependent default is used.

  • eps – step size used for finite differences.

  • err_msg – additional error message to include if checks fail.
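A minimal usage sketch. It assumes float64 is enabled so that finite-difference round-off stays well below the default tolerances; the function `f` and the argument values are illustrative only. Passing `functools.partial(jax.vjp, f)` as `f_vjp` checks JAX's own reverse-mode rule for `f` against finite differences.

```python
import functools

import jax
import jax.numpy as jnp
from jax.test_util import check_vjp

# Finite differences are much more accurate in float64; enable it before creating arrays.
jax.config.update("jax_enable_x64", True)


def f(x, y):
    # Illustrative function with two array inputs and one array output.
    return jnp.sin(x) * y + x ** 2


# args is a tuple; f is evaluated as f(*args).
args = (jnp.linspace(0.1, 1.0, 4), jnp.linspace(-1.0, 1.0, 4))

# f_vjp(*args) must return (primal_out, vjp_fun), exactly what jax.vjp(f, *args) returns.
check_vjp(f, functools.partial(jax.vjp, f), args)  # returns None on success
```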

Raises:

AssertionError – if gradients do not match.
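To see the AssertionError in practice, one can register a deliberately wrong backward rule with `jax.custom_vjp`. The `square` function and its buggy `square_bwd` below are hypothetical examples: the backward rule drops the factor of 2, so the gradient from automatic differentiation no longer agrees with finite differences of the forward computation. As above, float64 is assumed to be enabled.

```python
import functools

import jax
import jax.numpy as jnp
from jax.test_util import check_vjp

jax.config.update("jax_enable_x64", True)


@jax.custom_vjp
def square(x):
    return x ** 2


def square_fwd(x):
    # Forward pass: return the output and save x as the residual.
    return x ** 2, x


def square_bwd(x, g):
    # Deliberately wrong backward rule: the correct cotangent is 2 * x * g.
    return (x * g,)


square.defvjp(square_fwd, square_bwd)

args = (jnp.linspace(0.5, 2.0, 3),)
try:
    check_vjp(square, functools.partial(jax.vjp, square), args, err_msg="square")
    caught = False
except AssertionError:
    # The buggy VJP is detected; err_msg is included in the failure message.
    caught = True
```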