jax.custom_batching.custom_vmap.def_vmap
- custom_vmap.def_vmap(vmap_rule)
Define the vmap rule for this custom_vmap function.
- Parameters:
vmap_rule (Callable[..., tuple[Any, Any]]) – A function that implements the vmap rule. This function should accept the following arguments: (1) an integer axis_size as its first argument, (2) a pytree of booleans with the same structure as the inputs to the function, specifying whether each argument is batched, and (3) the function's arguments, passed positionally, with batched arguments carrying a batch dimension of size axis_size. It should return a tuple of the batched output and a pytree of booleans with the same structure as the output, specifying whether each output element is batched. See the documentation for jax.custom_batching.custom_vmap() for some examples.
- Returns:
This method passes the rule through, returning vmap_rule unchanged, so def_vmap can be used as a decorator.
- Return type:
Callable[…, tuple[Any, Any]]
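A minimal sketch of a vmap rule registered with def_vmap follows. The function scaled_add and its rule scaled_add_vmap_rule are hypothetical examples, not part of JAX. The sketch assumes, as in the examples for jax.custom_batching.custom_vmap(), that a batched argument arrives with its batch axis first. The rule broadcasts any unbatched argument to that shape and reports its single array output as batched.

```python
import jax
import jax.numpy as jnp
from jax import custom_batching


@custom_batching.custom_vmap
def scaled_add(x, y):
    # Hypothetical example function; unbatched semantics are x + 2 * y.
    return x + 2.0 * y


@scaled_add.def_vmap
def scaled_add_vmap_rule(axis_size, in_batched, x, y):
    # in_batched mirrors the positional inputs: one boolean per argument,
    # because x and y are plain arrays rather than nested pytrees.
    x_batched, y_batched = in_batched
    # Assumption: batched arguments carry the batch axis at position 0, so an
    # unbatched argument is broadcast to gain a leading axis of length axis_size.
    if not x_batched:
        x = jnp.broadcast_to(x, (axis_size,) + x.shape)
    if not y_batched:
        y = jnp.broadcast_to(y, (axis_size,) + y.shape)
    out = x + 2.0 * y
    # The output is a single array batched along axis 0, so the matching
    # out_batched pytree is the single boolean True.
    return out, True


xs = jnp.arange(6.0).reshape(3, 2)
y = jnp.array([1.0, 10.0])
# Only the first argument is mapped; y is shared across the batch.
out = jax.vmap(scaled_add, in_axes=(0, None))(xs, y)
```

Because def_vmap hands the rule back unchanged, scaled_add_vmap_rule remains an ordinary Python function after decoration and can be called directly, for example in a unit test of the rule itself.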