jax.experimental.pallas.mosaic_gpu.copy_smem_to_gmem

jax.experimental.pallas.mosaic_gpu.copy_smem_to_gmem(src, dst, predicate=None, *, commit_group=True, reduction_op=None)

Asynchronously copies an SMEM reference to a GMEM reference.

Parameters:
  • src (_Ref) – The SMEM reference to copy from.

  • dst (_Ref) – The GMEM reference to copy to.

  • predicate (jax.Array | None) – A boolean indicating whether the copy should be performed. If None, the copy is always performed.

  • commit_group (bool) – If True, this and any previously uncommitted copies are committed to a group and can be awaited jointly via jax.experimental.mosaic.gpu.wait_smem_to_gmem().

  • reduction_op (mgpu.ReductionOp | None) – If set, perform the specified reduction operation when storing to GMEM. For example, using "add" is conceptually equivalent to doing dst += src.
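To make the predicate and reduction_op semantics concrete, here is a minimal host-side sketch in NumPy of the value dst holds once the copy has completed. The helper name reference_copy_smem_to_gmem is hypothetical: it ignores asynchrony and memory spaces entirely, models only the "add" reduction mentioned above, and is not how the kernel-side function is implemented.

```python
import numpy as np


def reference_copy_smem_to_gmem(src, dst, predicate=None, reduction_op=None):
    """Hypothetical host-side model (not a JAX API) of dst after the copy completes."""
    src = np.asarray(src)
    dst = np.asarray(dst)
    if src.shape != dst.shape:
        raise ValueError(f"shape mismatch: {src.shape} vs {dst.shape}")
    if predicate is not None and not bool(predicate):
        # Predicate is False: the copy is skipped and dst keeps its old contents.
        return dst.copy()
    if reduction_op is None:
        # Plain copy: dst is overwritten with src.
        return src.astype(dst.dtype)
    if reduction_op == "add":
        # "add" reduction: conceptually dst += src, elementwise.
        return dst + src
    # Other mgpu.ReductionOp values are deliberately not modeled in this sketch.
    raise NotImplementedError(f"reduction_op={reduction_op!r} is not modeled")
```

For example, with src = [0, 1, 2, 3] and dst = [1, 1, 1, 1], reduction_op="add" leaves dst as [1, 2, 3, 4], while a False predicate leaves dst unchanged.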

Return type:

None

See also

  • jax.experimental.mosaic.gpu.wait_smem_to_gmem()

  • jax.experimental.mosaic.gpu.commit_smem()
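The commit_group flag groups copies so that they can be awaited jointly. The pure-Python sketch below models only that bookkeeping. BulkGroupTracker and its methods are hypothetical, and the wait(n) rule (retire the oldest groups until at most n remain pending) is an assumption borrowed from the PTX cp.async.bulk.wait_group semantics that such copies are assumed to lower to; this page does not state it.

```python
from collections import deque


class BulkGroupTracker:
    """Hypothetical model of commit-group bookkeeping (not a JAX API)."""

    def __init__(self):
        self.uncommitted = []   # copies issued with commit_group=False
        self.groups = deque()   # committed groups, oldest first

    def issue(self, name, commit_group=True):
        """Record one copy; optionally close the current group."""
        self.uncommitted.append(name)
        if commit_group:
            # This copy and all earlier uncommitted copies form one group.
            self.groups.append(tuple(self.uncommitted))
            self.uncommitted.clear()

    def wait(self, n):
        """Retire the oldest groups until at most n remain; return retired copies."""
        done = []
        while len(self.groups) > n:
            done.extend(self.groups.popleft())
        return done
```

Issuing copy "a" with commit_group=False and then copy "b" with the default commit_group=True yields a single group containing both, so one wait retires them together.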