jax.custom_vjp.defvjp
- custom_vjp.defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)
Define a custom VJP rule for the function represented by this instance.
- Parameters:
fwd (Callable[..., tuple[ReturnValue, Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
bwd (Callable[..., tuple[Any, ...]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
symbolic_zeros (bool) –
boolean, determining whether to indicate symbolic zeros to the fwd and bwd rules. Enabling this option allows custom derivative rules to detect when certain inputs, and when certain output cotangents, are not involved in differentiation. If True:
fwd must accept, in place of each leaf value x in the pytree comprising an argument to the original function, an object (of type jax.custom_derivatives.CustomVJPPrimal) with two attributes instead: value and perturbed. The value field is the original primal argument, and perturbed is a boolean. The perturbed bit indicates whether the argument is involved in differentiation (i.e., if it is False, then the corresponding Jacobian “column” is zero).
bwd will be passed objects representing static symbolic zeros in its cotangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed.
Setting this option to True allows these rules to detect whether certain inputs and outputs are not involved in differentiation, but at the cost of special handling. For instance:
The signature of fwd changes, and the objects it is passed cannot be output from the rule directly.
The bwd rule is passed objects that are not entirely array-like, and that cannot be passed to most jax.numpy functions.
Any custom pytree nodes involved in the primal function’s arguments must accept, in their unflattening functions, the two-field record objects that are given as input leaves to the fwd rule.
Default False.
optimize_remat (bool) – boolean, an experimental flag to enable an automatic optimization when this function is used under jax.remat(). This will be most useful when the fwd rule is an opaque call such as a Pallas kernel or a custom call. Default False.
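The special handling that symbolic_zeros=True requires can be sketched as follows. This is a minimal illustration rather than a canonical pattern: the function names scaled_pair, scaled_pair_fwd and scaled_pair_bwd are hypothetical. It assumes that jax.custom_derivatives.SymbolicZero is the type used for the static zeros handed to bwd, and that replacing such a zero with an ordinary zero array is an acceptable fallback.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

@jax.custom_vjp
def scaled_pair(x, y):
    # Hypothetical primal function with two outputs.
    return x * y, x + y

def scaled_pair_fwd(x, y):
    # With symbolic_zeros=True, each input leaf arrives as a record with
    # .value (the primal) and .perturbed (a boolean), so unwrap first.
    xv, yv = x.value, y.value
    return (xv * yv, xv + yv), (xv, yv)

def scaled_pair_bwd(res, cts):
    xv, yv = res
    ct_prod, ct_sum = cts
    # A cotangent may be a symbolic zero, which most jax.numpy functions
    # do not accept. Swap it for an ordinary zero before doing arithmetic.
    if isinstance(ct_prod, SymbolicZero):
        ct_prod = jnp.zeros_like(xv)
    if isinstance(ct_sum, SymbolicZero):
        ct_sum = jnp.zeros_like(xv)
    # One cotangent per primal argument (x, y).
    return (ct_prod * yv + ct_sum, ct_prod * xv + ct_sum)

scaled_pair.defvjp(scaled_pair_fwd, scaled_pair_bwd, symbolic_zeros=True)

# Differentiate only through the first output, with respect to x.
grad_x = jax.grad(lambda x, y: scaled_pair(x, y)[0])(2.0, 3.0)
```

Here only the first output feeds the gradient, so the cotangent for the second output is the kind of value the isinstance checks are meant to guard against. A rule could also use the perturbed flags to skip work for inputs that are not being differentiated.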
- Returns:
None.
- Return type:
None
Examples
>>> @jax.custom_vjp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
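The requirement that the output of bwd match the structure of the primal input arguments can be illustrated with a pytree argument. The sketch below is not part of the original docstring; the function affine and the dictionary keys "w" and "b" are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def affine(params, x):
    # params is a dict pytree; the output is a scalar.
    return jnp.dot(params["w"], x) + params["b"]

def affine_fwd(params, x):
    # Return (primal output, residuals needed by the backward pass).
    return affine(params, x), (params["w"], x)

def affine_bwd(res, g):
    w, x = res
    # One tuple entry per primal argument. The first entry mirrors the
    # dict structure of params, the second matches the array x.
    grad_params = {"w": g * x, "b": g}
    grad_x = g * w
    return (grad_params, grad_x)

affine.defvjp(affine_fwd, affine_bwd)

params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.float32(0.5)}
x = jnp.array([4.0, 5.0, 6.0])
grads = jax.grad(affine, argnums=(0, 1))(params, x)
```

For this function the gradient with respect to params["w"] equals x, the gradient with respect to params["b"] is 1, and the gradient with respect to x equals params["w"], so the custom rule agrees with what ordinary differentiation would give.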