jax.experimental.custom_dce.custom_dce
class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())
Customize the DCE behavior of a JAX-transformable function.
JAX uses dead code elimination (DCE) to remove unused computations from a JAX program. This typically works transparently when the program is completely specified by known JAX operations, but opaque kernels, such as calls to pallas_call() or ffi_call(), can cause problems because JAX cannot look inside them to skip the work for outputs that go unused.

In JAX, DCE is performed when a function is staged out using jax.jit(), so it won’t be applied when running JAX in eager mode. For the same reason, the custom_dce decorator requires that both the decorated function and the custom DCE rule be compatible with jit().

This decorator allows users to customize the DCE behavior of a function by defining a custom DCE rule. For a custom_dce-wrapped function f(*args), the signature of the DCE rule is dce_rule(used_outs, *args), where used_outs is a Pytree with the same structure as the output of f, and each leaf is a bool indicating whether the corresponding output should be computed. The remaining arguments *args are the original arguments to f. The rule dce_rule should return a Pytree with the same structure as the original output of f, but any unused outputs can be replaced with None.

For example:
>>> import jax.experimental.custom_dce
>>> import jax.numpy as jnp
>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )
In this example, used_outs is a tuple with two bool values indicating which outputs are required. The DCE rule computes only the required outputs and replaces the unused ones with None.

If the static_argnums argument is provided to custom_dce, the indicated arguments are treated as static when the function is traced, and they are moved to the front when the DCE rule is called. For example, if fun takes two arguments, fun(x, y), and static_argnums is (1,), then the DCE rule will be called as dce_rule(y, used_outs, x).
Methods

__init__(fun, *[, static_argnums])
def_dce(dce_rule): Define a custom DCE rule for this function.

Attributes
fun
static_argnums
dce_rule