jax.custom_batching.custom_vmap.def_vmap

custom_vmap.def_vmap(vmap_rule)

Define the vmap rule for this custom_vmap function.

Parameters:

vmap_rule (Callable[..., tuple[Any, Any]]) – A function that implements the vmap rule. It must accept, in order: (1) an integer axis_size, the size of the mapped axis; (2) a pytree of booleans with the same structure as the inputs to the function, indicating whether each input is batched; and (3) the batched arguments themselves. It must return a tuple of two items: the batched output, and a pytree of booleans with the same structure as that output, indicating whether each output element is batched. See the documentation for jax.custom_batching.custom_vmap() for examples.
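The rule signature described above can be sketched as follows. This is a minimal illustration, not the library's reference example: the names vector_dot and vector_dot_vmap_rule are hypothetical, and the sketch assumes that in_batched arrives as a sequence with one entry per positional argument and that batched arguments carry their batch axis in the leading position.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def vector_dot(x, y):
    # Unbatched definition: dot product of two 1-D vectors.
    return jnp.dot(x, y)


@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, xs, ys):
    # Assumption: one boolean per positional argument.
    x_batched, y_batched = in_batched
    # Give unbatched arguments an explicit leading batch axis so that
    # both operands have shape (axis_size, d).
    if not x_batched:
        xs = jnp.broadcast_to(xs, (axis_size,) + xs.shape)
    if not y_batched:
        ys = jnp.broadcast_to(ys, (axis_size,) + ys.shape)
    out = jnp.einsum("nd,nd->n", xs, ys)
    out_batched = True  # the single output always has a batch axis
    return out, out_batched


xs = jnp.arange(6.0).reshape(3, 2)
ys = jnp.ones((3, 2))

# Both arguments batched; per-row dot products are 1, 5, 9.
both_batched = jax.vmap(vector_dot)(xs, ys)

# Only the first argument batched; per-row dot products are 2, 8, 14.
first_batched = jax.vmap(vector_dot, in_axes=(0, None))(xs, jnp.array([1.0, 2.0]))
```

Returning True for out_batched tells vmap that the output has a batch axis. A rule whose output does not depend on any batched input could return False instead.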

Returns:

The vmap_rule argument, returned unchanged. Because the rule is passed through, def_vmap can be used as a decorator.
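The pass-through behavior can be checked with a short sketch. The names double and double_vmap_rule are hypothetical, and def_vmap is called explicitly here rather than as a decorator so that the returned object can be compared with the original rule.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def double(x):
    return 2.0 * x


def double_vmap_rule(axis_size, in_batched, xs):
    # Assumption: in_batched holds one boolean for the single argument.
    (x_batched,) = in_batched
    # Doubling is elementwise, so the output is batched iff the input is.
    return 2.0 * xs, x_batched


# Register the rule without decorator syntax and keep the return value.
returned = double.def_vmap(double_vmap_rule)
same_object = returned is double_vmap_rule

# The rule is still an ordinary function and can be called directly.
direct_out, direct_batched = double_vmap_rule(3, [True], jnp.arange(3.0))

# The registered rule is what jax.vmap uses for the wrapped function.
vmapped_out = jax.vmap(double)(jnp.arange(3.0))
```

Since the name bound by the decorator still refers to the plain rule function, the rule can be unit-tested on its own, separately from jax.vmap.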

Return type:

Callable[..., tuple[Any, Any]]