jax.test_util.check_vjp
- jax.test_util.check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=0.0001, err_msg='')
Check a VJP from automatic differentiation against finite differences.
Gradients are only checked in a single randomly chosen direction, which ensures that the finite difference calculation does not become prohibitively expensive even for large input/output spaces.
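The single-direction idea can be sketched as follows. This is a minimal illustration, not JAX's actual implementation. It draws one random input direction v and one random output cotangent u, then compares u · (J v), estimated by central differences, with (u^T J) · v, computed from the supplied VJP. Both sides equal u^T J v, so a correct VJP makes them agree up to finite-difference error. The helper names (directional_vjp_check, _random_like, _inner, doubled_vjp) are hypothetical. The sketch also assumes 64-bit mode, which keeps the finite-difference error well below the tolerances.

```python
import functools

import jax
import jax.numpy as jnp
from jax import tree_util

# Assumption of this sketch: float64 keeps central-difference error tiny.
jax.config.update("jax_enable_x64", True)


def _random_like(key, tree):
    """Standard-normal arrays with the same pytree structure, shapes and dtypes."""
    leaves, treedef = tree_util.tree_flatten(tree)
    keys = jax.random.split(key, len(leaves))
    draws = [jax.random.normal(k, jnp.shape(leaf), jnp.result_type(leaf))
             for k, leaf in zip(keys, leaves)]
    return tree_util.tree_unflatten(treedef, draws)


def _inner(a, b):
    """Sum of elementwise products across all leaves of two matching pytrees."""
    return sum(tree_util.tree_leaves(
        tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)))


def directional_vjp_check(f, f_vjp, args, eps=1e-4, rtol=1e-5, atol=1e-5, seed=0):
    """Hypothetical helper: compare u.(J v) (finite differences) with (u^T J).v (VJP)."""
    key_v, key_u = jax.random.split(jax.random.PRNGKey(seed))
    out, vjp_fn = f_vjp(*args)
    v = _random_like(key_v, args)  # one random input direction
    u = _random_like(key_u, out)   # one random output cotangent

    # J v by central differences: (f(x + eps v) - f(x - eps v)) / (2 eps).
    plus = tree_util.tree_map(lambda x, d: x + eps * d, args, v)
    minus = tree_util.tree_map(lambda x, d: x - eps * d, args, v)
    jv = tree_util.tree_map(lambda p, m: (p - m) / (2 * eps), f(*plus), f(*minus))

    fd = _inner(u, jv)          # u . (J v), numerical
    ad = _inner(vjp_fn(u), v)   # (u^T J) . v, from the supplied VJP
    return bool(jnp.allclose(fd, ad, rtol=rtol, atol=atol))


def f(x):
    return jnp.sin(x) * x ** 2


def doubled_vjp(*args):
    # Deliberately wrong VJP: every input cotangent is scaled by 2.
    out, vjp_fn = jax.vjp(f, *args)
    return out, lambda ct: tuple(2.0 * g for g in vjp_fn(ct))


x = jnp.array([0.3, -1.2, 2.0])
good = directional_vjp_check(f, functools.partial(jax.vjp, f), (x,))
bad = directional_vjp_check(f, doubled_vjp, (x,))
```

In this sketch the cost is one VJP call plus two extra evaluations of f, regardless of how large the inputs and outputs are.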
- Parameters:
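f – function to check at f(*args).
f_vjp – function that calculates jax.vjp applied to f. It is called as f_vjp(*args) and must return the (primal output, VJP function) pair that jax.vjp(f, *args) returns. Typically this should be functools.partial(jax.vjp, f).
args – tuple of argument values.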
atol – absolute tolerance for gradient equality.
rtol – relative tolerance for gradient equality.
eps – step size used for finite differences.
err_msg – additional error message to include if checks fail.
- Raises:
AssertionError – if gradients do not match.
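The following usage sketch calls check_vjp itself. It assumes 64-bit mode is enabled, so that central differences with the default eps comfortably meet the default tolerances. The function f and the deliberately broken doubled_vjp are illustrative only. The correct VJP passes silently, and the doubled one raises AssertionError.

```python
import functools

import jax
import jax.numpy as jnp
from jax.test_util import check_vjp

# Assumption: float64 so finite differences stay within default tolerances.
jax.config.update("jax_enable_x64", True)


def f(x):
    return jnp.sin(x) * x ** 2


x = jnp.array([0.3, -1.2, 2.0])

# Correct VJP: check_vjp returns None and raises nothing.
check_vjp(f, functools.partial(jax.vjp, f), (x,))
good_vjp_passed = True


def doubled_vjp(*args):
    # Hypothetical broken VJP: every input cotangent is doubled.
    out, vjp_fn = jax.vjp(f, *args)
    return out, lambda ct: tuple(2.0 * g for g in vjp_fn(ct))


try:
    check_vjp(f, doubled_vjp, (x,), err_msg="doubled VJP")
    bad_vjp_rejected = False
except AssertionError:
    bad_vjp_rejected = True
```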