jax.experimental.custom_dce.custom_dce

class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())

Customize the DCE behavior of a JAX-transformable function.

JAX uses dead code elimination (DCE) to remove unused computations from a JAX program. This typically works transparently when the program is completely specified by known JAX operations, but opaque kernels, such as calls to pallas_call() or ffi_call(), may cause problems.

In JAX, DCE is performed when a function is staged out using jax.jit(), so it won’t be applied when running JAX in eager mode. Similarly, the custom_dce decorator requires that both the decorated function and the custom DCE rule be compatible with jit().

This decorator allows users to customize the DCE behavior of a function by defining a custom DCE rule. For a custom_dce-wrapped function f(*args), the signature of the DCE rule is dce_rule(used_outs, *args), where used_outs is a Pytree with the same structure as the output of f, and each leaf is a bool indicating whether the corresponding output should be computed. The remaining arguments *args are the original arguments to f. The rule dce_rule should return a Pytree with the same structure as the original output of f, but any unused outputs can be replaced with None.

For example:

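>>> import jax
>>> import jax.experimental.custom_dce
>>> import jax.numpy as jnp
>>>
>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   # Compute each output only if its used_outs flag is True.
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )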

In this example, used_outs is a tuple with two bool values, indicating which outputs are required. The DCE rule only computes the required outputs, replacing the unused outputs with None.

If the static_argnums argument is provided to custom_dce, the indicated arguments are treated as static when the function is traced, and they will be moved to the front when calling the DCE rule. For example, if fun takes 2 arguments fun(x, y), and static_argnums is (1,), then the DCE rule will be called as dce_rule(y, used_outs, x).

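Parameters:
fun: the function to wrap; both it and its DCE rule must be compatible with jit().
static_argnums: indices of the arguments to treat as static when the function is traced (default: ()).

__init__(fun, *, static_argnums=())
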
Methods

__init__(fun, *[, static_argnums])

def_dce(dce_rule)

Define a custom DCE rule for this function.

Attributes

fun

static_argnums

dce_rule