jax.experimental.custom_dce.custom_dce.def_dce
- custom_dce.def_dce(dce_rule)
Define a custom DCE rule for this function.
- Parameters:
dce_rule (Callable[[...], Any]) – A function that takes (a) any arguments indicated as static using static_argnums, (b) a Pytree of bool values (used_outs) indicating which outputs should be computed, and (c) the rest of the (non-static) arguments to the original function. The rule should return a Pytree with the same structure as the output of the original function, but any unused outputs (as indicated by used_outs) can be replaced with None.
- Return type: