jax.experimental.pallas.mosaic_gpu.emit_pipeline
- jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0)
Creates a function to emit a manual pipeline within a Pallas kernel.
- Parameters:
body (Callable[..., None]) – The pipeline body, called with the indices for the current step, followed by the input refs and then the output refs.
grid (pallas_core.TupleGrid) – The grid to use for the pipeline.
in_specs (Sequence[pallas_core.BlockSpec]) – The block specs for the inputs.
out_specs (Sequence[pallas_core.BlockSpec]) – The block specs for the outputs.
max_concurrent_steps (int) – The maximum number of sequential stages that are active concurrently. Defaults to 1.
delay_release (int) – The number of steps to wait before reusing the input/output references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you’ll want to set it to 1 if you don’t await the WGMMA in the body.
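The interaction between max_concurrent_steps and delay_release may be easier to see with a toy model. The sketch below is plain Python and does not call JAX or reproduce the library's implementation. It rests on two assumptions: each step uses buffer slot `step % max_concurrent_steps`, and a slot is handed back for refilling only `delay_release` steps after the step that used it. The helper name `model_pipeline_slots` and its return format are hypothetical, chosen for illustration only.

```python
def model_pipeline_slots(num_steps, max_concurrent_steps=1, delay_release=0):
    """Toy model of buffer-slot reuse in a software pipeline.

    Assumption (not taken from the library source): step i uses slot
    i % max_concurrent_steps, and that slot may be refilled only after
    step i + delay_release has finished.
    """
    if max_concurrent_steps < 1:
        raise ValueError("max_concurrent_steps must be at least 1")
    if not 0 <= delay_release < max_concurrent_steps:
        raise ValueError(
            "delay_release must be non-negative and strictly smaller than "
            "max_concurrent_steps"
        )

    events = []
    for step in range(num_steps):
        slot = step % max_concurrent_steps
        # The step whose buffers become reusable once `step` has finished.
        released_step = step - delay_release
        released_slot = None
        refill_step = None
        if released_step >= 0:
            released_slot = released_step % max_concurrent_steps
            # The freed slot is next needed max_concurrent_steps later.
            candidate = released_step + max_concurrent_steps
            if candidate < num_steps:
                refill_step = candidate
        events.append(
            {
                "step": step,
                "slot": slot,
                "released_slot": released_slot,
                "refill_step": refill_step,
            }
        )
    return events


if __name__ == "__main__":
    for event in model_pipeline_slots(4, max_concurrent_steps=2, delay_release=1):
        print(event)
```

In this model, the refill issued after step i targets step i + max_concurrent_steps - delay_release. If delay_release were equal to max_concurrent_steps, that look-ahead would shrink to zero and nothing could be fetched ahead of time, which is one way to read the "strictly smaller" requirement. With delay_release=1, the slot read by step i stays untouched while step i + 1 runs, which matches the case where work started in the body (such as an un-awaited WGMMA) may still be reading its inputs.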