jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem

jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem(n, wait_read_only=False)

Waits until there are no more than n SMEM->GMEM copies in flight.
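To make the counting rule concrete, here is a toy pure-Python model. It does not use JAX or a GPU, and the class InFlightCopies and its methods are hypothetical names used only for illustration. It assumes that, as with PTX `cp.async.bulk.wait_group`, a wait leaves at most the n most recently issued copies pending and guarantees that every older copy has completed.

```python
from collections import deque


class InFlightCopies:
    """Hypothetical toy model of outstanding SMEM->GMEM copies (not a JAX API)."""

    def __init__(self):
        self.pending = deque()  # labels of in-flight copies, oldest first
        self.completed = []     # labels in the order their completion is observed

    def issue(self, label):
        # Starting an asynchronous copy only records it as in flight.
        self.pending.append(label)

    def wait(self, n):
        # Models wait_smem_to_gmem(n): the oldest copies are retired until
        # no more than n remain in flight (assumed in-order guarantee).
        while len(self.pending) > n:
            self.completed.append(self.pending.popleft())


# Double-buffered output: before overwriting a slot, allow only the copy
# issued from the *other* slot to remain in flight.
copies = InFlightCopies()
for step in range(4):
    slot = step % 2
    copies.wait(1)  # the copy issued two steps ago (same slot) is now done
    copies.issue(f"step{step}/slot{slot}")
copies.wait(0)  # drain every remaining copy
```

Passing n=1 instead of 0 lets one copy overlap with the next step's work; n=0 is a full drain of all outstanding copies.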

Parameters:
  • n (int) – The maximum number of copies that may still be in flight when the call returns; n=0 waits for all outstanding copies.

  • wait_read_only (bool) – If True, only wait until the in-flight copies have finished reading from SMEM, so their source buffers can be safely reused. Their writes to GMEM are not waited for.
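The difference between the two wait modes can be sketched with another toy model in which each copy has two phases: reading its SMEM source, then writing to GMEM. Copy, CopyQueue and their fields are hypothetical names, not part of the Pallas API. The model assumes that the SMEM read of a copy always finishes before its GMEM write, and that a copy stops counting as in flight only once its GMEM write is done.

```python
from dataclasses import dataclass, field


@dataclass
class Copy:
    """Hypothetical toy model of one SMEM->GMEM copy (not a JAX API)."""
    label: str
    smem_read_done: bool = False
    gmem_write_done: bool = False


@dataclass
class CopyQueue:
    pending: list = field(default_factory=list)  # in-flight copies, oldest first

    def issue(self, label):
        copy = Copy(label)
        self.pending.append(copy)
        return copy

    def wait(self, n, wait_read_only=False):
        # Every copy older than the n most recent must reach the requested phase.
        older = self.pending[: max(len(self.pending) - n, 0)]
        for copy in older:
            copy.smem_read_done = True  # the SMEM read finishes first
            if not wait_read_only:
                copy.gmem_write_done = True
        # Copies whose GMEM write is still outstanding remain in flight.
        self.pending = [c for c in self.pending if not c.gmem_write_done]


q = CopyQueue()
a = q.issue("a")
q.wait(0, wait_read_only=True)
# In this model, the SMEM source of "a" may now be overwritten,
# but GMEM is not yet guaranteed to hold its data.
q.wait(0)
```

A read-only wait is therefore suited to reusing an SMEM buffer, while a full wait is needed before relying on the GMEM contents.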

Return type:

None