jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem

jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None)

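Asynchronously copies a reference in GPU global memory (GMEM) to a reference in shared memory (SMEM). The call returns without waiting for the copy to finish. Completion is signalled on barrier, so wait on barrier before reading dst.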

See also

  • jax.experimental.mosaic.gpu.barrier_arrive()

  • jax.experimental.mosaic.gpu.barrier_wait()

Parameters:
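  • src (_Ref) – The GMEM reference to copy from.

  • dst (_Ref) – The SMEM reference to copy into.

  • barrier (_Ref) – The barrier that signals when the copy has completed. Wait on it before reading dst.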

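  • collective_axes (str | tuple[str, ...] | None) – Optional axis name, or tuple of axis names, along which the copy is performed collectively. Defaults to None (a regular, non-collective copy).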

Return type:

None
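Because the copy is asynchronous, dst must not be read until the data has landed, and barrier exists to enforce that. As a rough mental model only (this is plain Python, not Pallas or Mosaic GPU code), the sketch below uses threading.Event as a stand-in for the SMEM barrier. ModelBarrier and model_copy_gmem_to_smem are hypothetical names; in a real kernel the wait step corresponds to the barrier_wait() function listed under See also.

```python
import threading


class ModelBarrier:
    """Hypothetical stand-in for an SMEM barrier (not a JAX class).

    It completes once it has received ``num_arrivals`` arrivals.
    """

    def __init__(self, num_arrivals: int = 1) -> None:
        self._remaining = num_arrivals
        self._lock = threading.Lock()
        self._done = threading.Event()

    def arrive(self) -> None:
        # Record one arrival; release all waiters when the count reaches zero.
        with self._lock:
            self._remaining -= 1
            if self._remaining == 0:
                self._done.set()

    def wait(self) -> None:
        # Block until every expected arrival has happened.
        self._done.wait()


def model_copy_gmem_to_smem(src: list, dst: list, barrier: ModelBarrier) -> None:
    """Hypothetical model of an async copy: start it and return immediately.

    The copy runs on a background thread and arrives at ``barrier`` only
    after ``dst`` holds the data, so readers must call ``barrier.wait()`` first.
    """

    def _copy() -> None:
        dst[:] = src      # the "transfer"
        barrier.arrive()  # signal completion only after the data is in place

    threading.Thread(target=_copy).start()


gmem = [float(i) for i in range(8)]  # stands in for the GMEM source
smem = [0.0] * len(gmem)             # stands in for the SMEM destination
barrier = ModelBarrier(num_arrivals=1)

model_copy_gmem_to_smem(gmem, smem, barrier)
barrier.wait()                       # without this, smem could still be all zeros
result = [2.0 * x for x in smem]
```

The model only illustrates the ordering rule: issue the copy, wait on the barrier, then read dst.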