jax.experimental.pallas.mosaic_gpu.GPUCompilerParams
- class jax.experimental.pallas.mosaic_gpu.GPUCompilerParams(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, delay_release=0, profile_space=0, profile_dir='', lowering_semantics=LoweringSemantics.Lane)
Mosaic GPU compiler parameters.
- Parameters:
- approx_math
If True, the compiler is allowed to use approximate implementations of some math operations, e.g. exp. Defaults to False.
- Type:
bool
- dimension_semantics
A list of dimension semantics, one for each grid dimension of the kernel. Each entry is either “parallel”, for dimensions whose iterations can execute in any order, or “sequential”, for dimensions whose iterations must execute in order.
- Type:
Sequence[DimensionSemantics] | None
- max_concurrent_steps
The maximum number of sequential stages that can be active concurrently. Defaults to 1.
- Type:
int
- delay_release
The number of steps to wait before reusing the input/output references. Defaults to 0 and must be strictly smaller than max_concurrent_steps. Generally, you will want to set it to 1 if the body does not await the WGMMA, so that references the WGMMA may still be reading are not reused too early.
- Type:
int
- profile_space
The number of profiler events that can be collected in a single invocation. Collecting more events than this in a single thread is undefined behavior.
- Type:
int
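To make "approximate" in approx_math concrete: fast exponentials on NVIDIA GPUs are commonly computed as exp(x) = 2^(x * log2(e)) using the hardware's approximate exp2 (this is how CUDA's __expf intrinsic works). Whether Mosaic GPU lowers exp exactly this way when approx_math=True is an assumption here, not something this page states. The pure-Python sketch below models only the float32 rounding introduced by that rescaling; to_f32, approx_exp and rel_error are hypothetical helpers, not JAX APIs.

```python
import math
import struct

LOG2E = math.log2(math.e)  # log2(e) in double precision


def to_f32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def approx_exp(x: float) -> float:
    """exp(x) rewritten as 2 ** (x * log2(e)), with float32 rounding of the operands.

    Only the rescaling is modeled: 2.0 ** t is evaluated in double precision,
    whereas a hardware exp2 approximation would add its own error on top.
    """
    t = to_f32(to_f32(x) * to_f32(LOG2E))
    return 2.0 ** t


def rel_error(x: float) -> float:
    """Relative error of approx_exp against math.exp."""
    exact = math.exp(x)
    return abs(approx_exp(x) - exact) / exact


# The float32 rounding error in t scales with |t|, so the error bound grows with |x|.
for x in (0.5, 5.0, 50.0):
    print(f"x={x:5.1f}  relative error={rel_error(x):.2e}")
```

The trade-off this illustrates is why the flag is opt-in: the rewrite is cheap, but its worst-case relative error grows with the magnitude of the argument.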
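A simplified model of what the two dimension_semantics values permit. In a kernel the list is passed as GPUCompilerParams(dimension_semantics=[...]); the helpers below (simulated_order, running_row_sums) are hypothetical, and the scheduling model (shuffle indices along "parallel" dimensions, walk "sequential" dimensions in ascending order) is an illustration rather than how Mosaic GPU actually schedules work. It shows why an axis that carries state from one grid step to the next, such as a reduction axis, is marked "sequential", while independent rows can be "parallel".

```python
import itertools
import random
from typing import List, Literal, Sequence, Tuple

# Mirrors the two documented values; defined here rather than imported from JAX.
DimensionSemantics = Literal["parallel", "sequential"]


def simulated_order(grid: Sequence[int],
                    semantics: Sequence[DimensionSemantics],
                    seed: int = 0) -> List[Tuple[int, ...]]:
    """Return one legal visiting order of `grid` under a simplified model.

    Points along "parallel" dimensions are shuffled (any order is allowed);
    "sequential" dimensions are always walked in ascending order.
    """
    if len(grid) != len(semantics):
        raise ValueError("expected one semantics entry per grid dimension")
    par = [d for d, s in enumerate(semantics) if s == "parallel"]
    seq = [d for d, s in enumerate(semantics) if s == "sequential"]
    par_points = list(itertools.product(*(range(grid[d]) for d in par)))
    random.Random(seed).shuffle(par_points)  # any order is legal here
    order = []
    for p in par_points:
        for q in itertools.product(*(range(grid[d]) for d in seq)):
            idx = [0] * len(grid)
            for d, v in zip(par, p):
                idx[d] = v
            for d, v in zip(seq, q):
                idx[d] = v
            order.append(tuple(idx))
    return order


def running_row_sums(data: List[List[int]], seed: int) -> List[List[int]]:
    """Kernel-like body over grid (rows, cols) that carries a sum along cols."""
    rows, cols = len(data), len(data[0])
    out = [[0] * cols for _ in range(rows)]
    acc = [0] * rows  # per-row carry, like an accumulator scratch buffer
    for i, k in simulated_order((rows, cols), ("parallel", "sequential"), seed):
        # k == 0 starts a new row; later k extend the carry, so k must be ascending.
        acc[i] = data[i][k] if k == 0 else acc[i] + data[i][k]
        out[i][k] = acc[i]
    return out
```

Because each row keeps its own carry, shuffling the row order never changes the result; shuffling the column order would, since k == 0 resets the carry.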
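How max_concurrent_steps and delay_release interact can be sketched with a ring of buffer slots. In this hypothetical model (pipeline_schedule is an illustrative name, not the Pallas implementation), step i uses slot i mod N, where N = max_concurrent_steps, and after the body runs, the copy for step i + N - d is issued, where d = delay_release. That copy lands in slot (i - d) mod N, the slot last used by step i - d. With d = 0 the copy overwrites the buffer the body just read, which is only safe once every read of it, including an asynchronous WGMMA, has finished; with d = 1 the current step's buffer is left alone for one more step. With d = N the "prefetch" would target step i itself, which is one way to see why d must be strictly smaller than N. In a kernel both values are set through GPUCompilerParams(max_concurrent_steps=..., delay_release=...).

```python
from typing import List, Optional, Tuple


def pipeline_schedule(num_steps: int, max_concurrent_steps: int,
                      delay_release: int = 0) -> List[Tuple[int, int, Optional[int]]]:
    """Hypothetical ring-buffer model of a pipelined loop.

    Returns (step, slot, refill_step) triples: the body for `step` reads
    buffer `slot`; afterwards the copy for `refill_step` is issued into the
    slot last used by step `step - delay_release`, if any (None when no
    steps are left to fetch).
    """
    if max_concurrent_steps < 1:
        raise ValueError("max_concurrent_steps must be at least 1")
    if not 0 <= delay_release < max_concurrent_steps:
        raise ValueError("delay_release must satisfy 0 <= delay_release < max_concurrent_steps")
    schedule = []
    for step in range(num_steps):
        slot = step % max_concurrent_steps
        # Each step of delay shortens how far ahead the refill can look.
        refill = step + max_concurrent_steps - delay_release
        schedule.append((step, slot, refill if refill < num_steps else None))
    return schedule


for row in pipeline_schedule(num_steps=5, max_concurrent_steps=2, delay_release=1):
    print(row)
```

With max_concurrent_steps=1 the refill always targets the slot the body just used, so in this model copies and compute never overlap; larger values let later copies be in flight while the body runs.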
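Because exceeding profile_space is undefined behavior rather than a reported error, it helps to size it from the kernel's structure before enabling profiling. The sketch below assumes, hypothetically, that each visit to a profiled region records two events (a start and an end); events_needed and EventBuffer are illustrative names, not part of the Pallas profiler, and the model raises an error where a real kernel would silently misbehave.

```python
from typing import List, Tuple


def events_needed(regions: int, iterations: int, events_per_region: int = 2) -> int:
    """Capacity estimate, assuming each region visit records a start and an end event."""
    return regions * iterations * events_per_region


class EventBuffer:
    """Hypothetical fixed-capacity, per-thread event log.

    On overflow the real profiler has undefined behavior; this model raises
    instead so the mistake is visible.
    """

    def __init__(self, profile_space: int) -> None:
        self.capacity = profile_space
        self.events: List[Tuple[str, int]] = []

    def record(self, name: str, timestamp: int) -> None:
        if len(self.events) >= self.capacity:
            raise OverflowError(f"more than {self.capacity} events in one invocation")
        self.events.append((name, timestamp))


# Size the buffer from the kernel's structure: 3 regions visited 4 times each.
space = events_needed(regions=3, iterations=4)
buf = EventBuffer(space)
for t in range(space):
    buf.record("region", t)
```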
- __init__(*, approx_math=False, dimension_semantics=None, max_concurrent_steps=1, delay_release=0, profile_space=0, profile_dir='', lowering_semantics=LoweringSemantics.Lane)
Methods
- __init__(*[, approx_math, ...])
Attributes
- BACKEND
- lowering_semantics