jax.test_util.check_grads#

jax.test_util.check_grads(f, args, order, modes=('fwd', 'rev'), atol=None, rtol=None, eps=None)[source]#

Check gradients from automatic differentiation against finite differences.

Gradients are only checked in a single randomly chosen direction, which ensures that the finite difference calculation does not become prohibitively expensive even for large input/output spaces.

Parameters:
  • f – function to check at f(*args).

  • args – tuple of argument values.

  • order – forward and backwards gradients up to this order are checked.

  • modes – lists of gradient modes to check (‘fwd’ and/or ‘rev’).

  • atol – absolute tolerance for gradient equality.

  • rtol – relative tolerance for gradient equality.

  • eps – step size used for finite differences.

Raises:

AssertionError – if gradients do not match.