jax.experimental.pallas.mosaic_gpu.wgmma_wait

jax.experimental.pallas.mosaic_gpu.wgmma_wait(n)

Waits until no more than n WGMMA operations remain in flight.

Parameters:

n (int) – the maximum number of WGMMA operations that may still be in flight when the call returns.
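The waiting behaviour can be illustrated with a small CPU-only model. This is a sketch, not JAX code: ToyWgmmaQueue, issue and wait are hypothetical names, and the model simplifies by assuming in-flight operations retire in the order they were issued. The loop mimics a typical software-pipelining pattern: after issuing a new operation, waiting with n=1 lets that operation keep running while guaranteeing that the previous one has finished. A final wait with n=0 then drains everything.

```python
from collections import deque


class ToyWgmmaQueue:
    """Hypothetical CPU-only model of asynchronous WGMMA operations (not a JAX API)."""

    def __init__(self):
        self.in_flight = deque()  # oldest issued operation on the left
        self.completed = []

    def issue(self, name):
        # Issuing returns immediately; the operation is now "in flight".
        self.in_flight.append(name)

    def wait(self, n):
        # Model of wgmma_wait(n): return once at most n operations remain in flight.
        # Simplifying assumption: operations retire in issue order (oldest first).
        if n < 0:
            raise ValueError("n must be non-negative")
        while len(self.in_flight) > n:
            self.completed.append(self.in_flight.popleft())


q = ToyWgmmaQueue()
history = []  # in-flight operations observed after each wait(1)
for step in range(3):
    q.issue(f"mma{step}")
    q.wait(1)  # the previous step's operation is now guaranteed complete
    history.append(list(q.in_flight))
q.wait(0)  # wait for every outstanding operation
```

Under this model, passing n=0 corresponds to waiting for all outstanding WGMMA operations, while a larger n trades stricter completion guarantees for more overlap between issued operations.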