jax.experimental.pallas.run_scoped
- jax.experimental.pallas.run_scoped(f, *types, collective_axes=(), **kw_types)
Calls f with newly allocated references and returns the result of f.
The positional and keyword arguments (types and kw_types) describe which reference type to allocate for each corresponding argument of f. Each backend has its own set of reference types in addition to jax.experimental.pallas.MemoryRef. When collective_axes is specified, the same allocation is returned for all programs that differ only in their program ids along the collective axes. It is an error not to call the same run_scoped in every program along those axes.
- Parameters:
f (Callable[..., Any])
types (Any)
collective_axes (Hashable | tuple[Hashable, ...])
kw_types (Any)
- Return type:
Any