jax.custom_batching.sequential_vmap

jax.custom_batching.sequential_vmap(f)
A special case of custom_vmap that uses a loop.

When a function decorated with sequential_vmap is batched, it is called sequentially within a loop, once per element of the batch. This is useful for functions that don’t natively support batch dimensions.

For example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_batching.sequential_vmap
... def f(x):
...   jax.debug.print("{}", x)
...   return x + 1
...
>>> jax.vmap(f)(jnp.arange(3))
0
1
2
Array([1, 2, 3], dtype=int32)
The print statements show that this vmap() is evaluated with a loop: f runs once for each element of the batch, in order, rather than once on the whole batched array. See the documentation for custom_vmap for more details.
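The description of sequential_vmap as "a special case of custom_vmap that uses a loop" can be made concrete with a hand-written batching rule. The following is a minimal sketch, assuming a single-argument function whose batched input has its batch axis at position 0. It registers a custom_vmap rule that loops with jax.lax.map. The names loop_vmap and g are hypothetical and introduced only for illustration. The sketch is not claimed to be how sequential_vmap is actually implemented. It only checks that the loop-based strategies agree numerically with ordinary vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap, sequential_vmap


def loop_vmap(fun):
    """Hypothetical helper: batch a one-argument `fun` with a loop.

    Assumes `fun` takes a single array. This is an illustration,
    not JAX's own implementation of sequential_vmap.
    """
    fun = custom_vmap(fun)

    @fun.def_vmap
    def rule(axis_size, in_batched, xs):
        # The rule runs only when some input is batched. With a single
        # argument, `xs` carries the batch along its leading axis.
        del axis_size, in_batched
        # lax.map applies `fun` to one slice of `xs` at a time, in sequence.
        ys = lax.map(fun, xs)
        # Report the output as batched along its leading axis.
        return ys, True

    return fun


def g(x):
    # Hypothetical per-example computation on a single vector.
    return jnp.sum(x ** 2) + x[0]


xs = jnp.arange(12.0).reshape(4, 3)

via_sequential = jax.vmap(sequential_vmap(g))(xs)  # library loop
via_custom = jax.vmap(loop_vmap(g))(xs)            # hand-written loop
via_native = jax.vmap(g)(xs)                       # ordinary vectorized vmap
```

All three arrays hold the same values. The loop-based variants trade the single vectorized computation of plain vmap for one call per batch element. That is the point when a function has no usable batching rule of its own.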