jax.custom_batching.custom_vmap
class jax.custom_batching.custom_vmap(fun)
Customize the vmap behavior of a JAX-transformable function.
This decorator is used to customize the behavior of a JAX function under the jax.vmap() transformation. A custom_vmap-decorated function will mostly (see below for caveats) have the same behavior as the underlying function, except when batched using jax.vmap(). When batched, the rule defined using def_vmap() is used instead.

For example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_batching.custom_vmap
... def f(x, y):
...   return x + y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   assert all(in_batched)
...   assert xs.shape[0] == axis_size
...   assert ys.shape[0] == axis_size
...   out_batched = True
...   return xs * ys, out_batched
...
>>> xs = jnp.arange(3)
>>> ys = jnp.arange(1, 4)
>>> jax.vmap(f)(xs, ys)  # returns xs * ys instead of xs + ys
Array([0, 2, 6], dtype=int32)
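The rule in the example asserts that every argument is batched. When jax.vmap is called with in_axes=None for some argument, the matching entry of in_batched is False and that argument arrives without a leading batch axis. The following is a minimal sketch, assuming a rule that broadcasts any unbatched argument itself before computing the batched result; the specific in_axes choice and input values are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def f(x, y):
    # Unbatched behavior: the underlying function runs as written.
    return x + y


@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs, ys):
    # in_batched mirrors the argument structure: one bool per argument here.
    xs_batched, ys_batched = in_batched
    # Give any unbatched argument an explicit leading axis of length axis_size
    # so that both operands line up along the batch dimension.
    if not xs_batched:
        xs = jnp.broadcast_to(xs, (axis_size,) + xs.shape)
    if not ys_batched:
        ys = jnp.broadcast_to(ys, (axis_size,) + ys.shape)
    # The result always carries the batch axis, so report it as batched.
    return xs * ys, True


# Outside of vmap, f behaves like x + y.
direct = f(jnp.array(1), jnp.array(2))

# Only the first argument is batched; the second is shared across the batch.
half_batched = jax.vmap(f, in_axes=(0, None))(jnp.arange(3), jnp.array(2))
```

Broadcasting inside the rule is one design choice; for elementwise math like this, the rule could instead rely on ordinary NumPy-style broadcasting of the unbatched operand and skip the explicit copy.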
Of note, custom_vmap functions do not support reverse-mode autodiff. To customize both vmap and reverse-mode autodiff, combine custom_vmap with jax.custom_vjp. For example:

>>> @jax.custom_vjp
... @jax.custom_batching.custom_vmap
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   return jnp.cos(xs) * ys, True
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
Array([1., 1., 1.], dtype=float32)
>>> jax.grad(f)(jnp.zeros(()), jnp.ones(()))
Array(1., dtype=float32)
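The jax.grad call in the example differentiates only with respect to the first argument, at a point where sin(x) vanishes. A minimal sketch, assuming the same f, f_vmap_rule, f_fwd and f_bwd definitions, that requests both cotangents with argnums=(0, 1) at an arbitrarily chosen point and checks them against the analytic partials cos(x) * y and sin(x):

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@jax.custom_vjp
@custom_vmap
def f(x, y):
    return jnp.sin(x) * y


@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs, ys):
    # Deliberately differs from f so that batched calls are easy to recognize.
    return jnp.cos(xs) * ys, True


def f_fwd(x, y):
    # Residuals saved for the backward pass: cos(x), sin(x) and y.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)


def f_bwd(res, g):
    cos_x, sin_x, y = res
    # Analytic partials of sin(x) * y: d/dx = cos(x) * y, d/dy = sin(x).
    return (cos_x * g * y, sin_x * g)


f.defvjp(f_fwd, f_bwd)

# Cotangents for both arguments at a point where neither partial is trivial.
dx, dy = jax.grad(f, argnums=(0, 1))(jnp.float32(0.5), jnp.float32(2.0))

# Batched calls still route through f_vmap_rule: cos(0) * 1 == 1.
batched = jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
```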
Note that jax.custom_vjp must be the outer decorator, wrapping the custom_vmap-decorated function.

Parameters:
fun (Callable[..., Any])
Methods
__init__(fun)
def_vmap(vmap_rule): Define the vmap rule for this custom_vmap function.
Attributes
fun
vmap_rule
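def_vmap can also be called as an ordinary method rather than as a decorator; the wrapped function and the registered rule are then reachable through the fun and vmap_rule attributes. A minimal sketch, assuming (as in current JAX releases) that vmap_rule stores exactly the callable passed to def_vmap; square and square_vmap_rule are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


def square(x):
    return x * x


def square_vmap_rule(axis_size, in_batched, xs):
    # Hypothetical rule: same math as square, applied to the batched array.
    (xs_batched,) = in_batched
    return xs * xs, xs_batched


f = custom_vmap(square)
f.def_vmap(square_vmap_rule)  # non-decorator registration of the rule

out = jax.vmap(f)(jnp.arange(4.0))
```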