jax.custom_batching.sequential_vmap

jax.custom_batching.sequential_vmap(f)

A special case of custom_vmap whose batching rule applies the function in a loop over the batch.

A function decorated with sequential_vmap will be called sequentially within a loop when batched. This is useful for functions that don’t natively support batch dimensions.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_batching.sequential_vmap
... def f(x):
...   jax.debug.print("{}", x)
...   return x + 1
...
>>> jax.vmap(f)(jnp.arange(3))
0
1
2
Array([1, 2, 3], dtype=int32)

The print statements show that the batched function is evaluated in a loop, one element of the batch at a time, rather than by a vectorized batching rule.
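As a further illustration, the sketch below batches a sequential_vmap-decorated function over only one of its arguments and compares the result with an explicit jax.lax.map loop. The function name weighted_sum and the input values are made up for this example; the point is that unbatched arguments (here w, passed with in_axes=None) are handed unchanged to every per-element call, and the batched result agrees with the explicit loop.

```python
import jax
import jax.numpy as jnp


@jax.custom_batching.sequential_vmap
def weighted_sum(x, w):
    # Written for a single example: x and w are 1-D vectors of the same length.
    return jnp.dot(w, x)


xs = jnp.arange(6.0).reshape(3, 2)  # a batch of three 2-vectors
w = jnp.array([1.0, -1.0])          # shared weights, not batched

# Batch over xs only; w is passed through unbatched (in_axes=None).
batched = jax.vmap(weighted_sum, in_axes=(0, None))(xs, w)

# The same computation written as an explicit sequential loop.
looped = jax.lax.map(lambda x: weighted_sum(x, w), xs)
```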

See the documentation for custom_vmap for more details.
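For intuition about how a loop-based rule fits into custom_vmap, here is a minimal hand-written sketch. It is not JAX's actual implementation: the name my_sequential_vmap is hypothetical, and it handles unbatched arguments by broadcasting them along the batch axis before looping, which is a simplification. The rule signature (axis_size, in_batched, *args) returning (out, out_batched) follows the custom_vmap def_vmap convention.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap


def my_sequential_vmap(f):
    """Hypothetical re-creation of sequential_vmap on top of custom_vmap."""
    f = custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Give every argument a leading batch axis; unbatched ones are
        # broadcast (a simplification: this copies them axis_size times).
        args = jax.tree_util.tree_map(
            lambda is_batched, x: x
            if is_batched
            else jnp.broadcast_to(x, (axis_size,) + jnp.shape(x)),
            list(in_batched),
            list(args),
        )
        # Call f once per batch element, sequentially, via lax.map.
        out = lax.map(lambda a: f(*a), tuple(args))
        # Every output leaf now carries the batch axis.
        out_batched = jax.tree_util.tree_map(lambda _: True, out)
        return out, out_batched

    return f


# Usage: batch over x only, keeping y fixed.
add = my_sequential_vmap(lambda x, y: x + y)
xs = jnp.arange(3.0)
out = jax.vmap(add, in_axes=(0, None))(xs, jnp.array(10.0))
```

Broadcasting keeps the rule short at the cost of extra memory for unbatched inputs; a more careful version would pass those arguments into the loop body unchanged instead of mapping over them.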