jax.experimental.pallas.mosaic_gpu.emit_pipeline

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0)

Creates a function to emit a manual pipeline within a Pallas kernel.
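The calling contract (a function that is built first and then applied to full arrays, a body invoked once per grid point with the step's indices, its input blocks, and then its output blocks) can be illustrated with a small host-side model. This is only a sketch, not the Mosaic GPU implementation. `emulate_pipeline` is a hypothetical helper that runs every step sequentially with NumPy. It models each block spec as a `(block_shape, index_map)` pair whose index map returns block indices, and it assumes a row-major grid traversal. In a real kernel, the function returned by `emit_pipeline` would instead be applied to the kernel's GMEM refs (for example `plgpu.emit_pipeline(body, grid=..., in_specs=..., out_specs=...)(x_gmem, o_gmem)`), and the copies would overlap with the body.

```python
import itertools

import numpy as np


def emulate_pipeline(body, *, grid, in_specs=(), out_specs=()):
    """Hypothetical sequential model of a pipeline's data movement.

    Each spec is a (block_shape, index_map) pair; index_map maps the grid
    indices to block indices, which are scaled by block_shape to get slices.
    """

    def block_slices(spec, indices):
        block_shape, index_map = spec
        block_idx = index_map(*indices)
        return tuple(
            slice(b * size, (b + 1) * size) for b, size in zip(block_idx, block_shape)
        )

    def run(*arrays):
        ins, outs = arrays[: len(in_specs)], arrays[len(in_specs):]
        # Assumption: row-major traversal, the last grid axis varies fastest.
        for indices in itertools.product(*map(range, grid)):
            # "Copy in": stage one block of every input.
            in_blocks = [a[block_slices(s, indices)].copy() for a, s in zip(ins, in_specs)]
            # Scratch buffers for this step's output blocks.
            out_blocks = [np.empty(s[0], dtype=a.dtype) for a, s in zip(outs, out_specs)]
            # Documented calling convention: indices, input refs, output refs.
            body(indices, *in_blocks, *out_blocks)
            # "Copy out": write each output block back into its full array.
            for a, s, blk in zip(outs, out_specs, out_blocks):
                a[block_slices(s, indices)] = blk

    return run


# Usage: add 1.0 to a (64, 256) array, processed as four (64, 64) column blocks.
seen = []


def add_one(indices, x_blk, o_blk):
    seen.append(indices)
    o_blk[...] = x_blk + 1.0


x = np.arange(64 * 256, dtype=np.float32).reshape(64, 256)
o = np.zeros_like(x)
spec = ((64, 64), lambda i: (0, i))
emulate_pipeline(add_one, grid=(4,), in_specs=[spec], out_specs=[spec])(x, o)
```

The body writes its result into the output block it receives rather than returning a value, which matches the documented `Callable[..., None]` signature.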

Parameters:
  • body (Callable[..., None]) – The pipeline body. It is called once per step with the grid indices of the current step, followed by the input refs (one per entry of in_specs) and then the output refs (one per entry of out_specs).

  • grid (pallas_core.TupleGrid) – The grid to iterate over; the body runs once for every point in the grid.

  • in_specs (Sequence[pallas_core.BlockSpec]) – The block specs for the inputs.

  • out_specs (Sequence[pallas_core.BlockSpec]) – The block specs for the outputs.

  • max_concurrent_steps (int) – The maximum number of sequential stages that are active concurrently. Defaults to 1.

  • delay_release (int) – The number of steps to wait before reusing the input/output references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you’ll want to set it to 1 if the body issues a WGMMA without awaiting it, so that the buffers the WGMMA reads stay intact for one more step.
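The interaction between max_concurrent_steps and delay_release can be made concrete with a pure-Python model of a ring-buffer schedule. It is a sketch under stated assumptions, not the library's code. `slot_schedule` is a hypothetical helper that models input buffers only. It assumes that each step's blocks live in slot `step % max_concurrent_steps`, that the first max_concurrent_steps steps are fetched before the loop, and that once the body of step `s` finishes, the slot last used by step `s - delay_release` is refilled with the data for step `s + max_concurrent_steps - delay_release`.

```python
def slot_schedule(num_steps, max_concurrent_steps=1, delay_release=0):
    """Hypothetical model of ring-buffer reuse in a software pipeline.

    Returns (prologue, schedule): the steps fetched before the loop, and for
    each step a (slot, fetched_step) pair, where fetched_step is the step whose
    data is fetched after this step's body runs (None if nothing is fetched).
    """
    if not 0 <= delay_release < max_concurrent_steps:
        raise ValueError("delay_release must be in [0, max_concurrent_steps)")
    # Fill every slot before the first body runs.
    prologue = list(range(min(max_concurrent_steps, num_steps)))
    schedule = []
    for step in range(num_steps):
        slot = step % max_concurrent_steps
        # The slot used by step (step - delay_release) is released here and
        # refilled with a future step, which maps to that same slot index.
        fetch_step = step + max_concurrent_steps - delay_release
        if step >= delay_release and fetch_step < num_steps:
            schedule.append((slot, fetch_step))
        else:
            schedule.append((slot, None))
    return prologue, schedule


# Usage: six steps, two buffers per input, release delayed by one step.
prologue, schedule = slot_schedule(6, max_concurrent_steps=2, delay_release=1)
```

With delay_release=1, step 0 fetches nothing because its slot has to stay intact for one more step; from step 1 onward, each step refills the slot used by the step before it. This is why delay_release must be strictly smaller than max_concurrent_steps: otherwise no slot would ever be free to prefetch into.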