jax.experimental.pallas.mosaic_gpu.wgmma
- jax.experimental.pallas.mosaic_gpu.wgmma(acc, a, b)
Performs an asynchronous warp group matmul-accumulate on the given references.
Conceptually, this is equivalent to
acc[...] += a[...] @ b[...]
except that the computation is performed asynchronously.
- Parameters:
acc (gpu_core.WGMMAAbstractAccumulatorRef) – The accumulator reference. It must be allocated via jax.experimental.pallas.run_scoped() called with a jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef().
a – The left-hand side operand reference.
b – The right-hand side operand reference.
- Return type:
None
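A minimal sketch of the conceptual equivalence stated above, written as ordinary synchronous JAX rather than as a Pallas kernel. It does not call wgmma, needs no GPU, and only illustrates the arithmetic of the update. Because JAX arrays are immutable, the hypothetical helper `wgmma_reference` returns the updated accumulator instead of mutating a reference in place. The shapes, the bfloat16 operands and the float32 accumulator are arbitrary illustrative choices, not requirements stated on this page.

```python
import jax.numpy as jnp


def wgmma_reference(acc, a, b):
    """Hypothetical synchronous stand-in for `acc[...] += a[...] @ b[...]`.

    Returns the new accumulator value rather than updating a ref,
    since JAX arrays cannot be mutated in place.
    """
    # Compute the product directly in the accumulator's dtype (float32 here),
    # so low-precision operands are summed at higher precision.
    return acc + jnp.dot(a, b, preferred_element_type=acc.dtype)


# Illustrative shapes only: a is (m, k), b is (k, n), acc is (m, n).
m, k, n = 64, 32, 128
a = jnp.ones((m, k), dtype=jnp.bfloat16)
b = jnp.ones((k, n), dtype=jnp.bfloat16)
acc = jnp.zeros((m, n), dtype=jnp.float32)

acc = wgmma_reference(acc, a, b)  # every entry is now 32.0 (= k)
acc = wgmma_reference(acc, a, b)  # a second update adds on top: 64.0
```

Calling the helper twice adds a second product on top of the first, mirroring the `+=` in the conceptual definition; the real operation differs in that it runs asynchronously on the given references.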