jax.experimental.pallas.mosaic_gpu.wgmma

jax.experimental.pallas.mosaic_gpu.wgmma(acc, a, b)

Performs an asynchronous warp group matmul-accumulate on the given references.

Conceptually, this is equivalent to doing acc[...] += a[...] @ b[...], except that the computation is performed asynchronously.
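As a rough sketch, the synchronous semantics described above can be written in plain JAX, operating on arrays instead of references. The helper name `wgmma_reference` is hypothetical and not part of Pallas. It returns the updated accumulator rather than mutating a reference in place. It also assumes that the products are summed in the accumulator's dtype (for example float32 with bfloat16 operands). It does not model the asynchronous behavior of `wgmma`, but it can serve as a reference when checking a kernel's output.

```python
import jax.numpy as jnp


def wgmma_reference(acc, a, b):
    """Hypothetical synchronous reference for acc[...] += a[...] @ b[...].

    Takes arrays instead of refs and returns the new accumulator value.
    """
    # Assumption: products are accumulated in the accumulator's dtype,
    # so the operands are upcast before the matmul.
    update = a.astype(acc.dtype) @ b.astype(acc.dtype)
    return acc + update


acc = jnp.zeros((64, 128), jnp.float32)
a = jnp.ones((64, 32), jnp.bfloat16)
b = jnp.ones((32, 128), jnp.bfloat16)
acc = wgmma_reference(acc, a, b)  # every entry is now 32.0
acc = wgmma_reference(acc, a, b)  # accumulates again: every entry is 64.0
```

Calling the helper twice mirrors issuing two `wgmma` operations into the same accumulator. Each call adds a further `a @ b` instead of overwriting the previous result.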

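Parameters:
acc – The accumulator reference. The product a[...] @ b[...] is added to it.
a – The left-hand operand reference.
b – The right-hand operand reference.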
Return type:

None