jax.experimental.custom_dce.custom_dce.def_dce

custom_dce.def_dce(dce_rule)

Define a custom DCE rule for this function.

Parameters:

dce_rule (Callable[[...], Any]) – A function that takes, in order, (a) any arguments marked as static via static_argnums, (b) a Pytree of bool values (used_outs) indicating which outputs should be computed, and (c) the remaining (non-static) arguments to the original function. The rule should return a Pytree with the same structure as the output of the original function, but any unused outputs (as indicated by used_outs) can be replaced with None.

Return type:

Callable[[…], Any]