jax.experimental.pallas.mosaic_gpu.copy_smem_to_gmem
- jax.experimental.pallas.mosaic_gpu.copy_smem_to_gmem(src, dst, predicate=None, *, commit_group=True, reduction_op=None)
Asynchronously copies the contents of an SMEM (shared memory) reference to a GMEM (global memory) reference.
- Parameters:
src (_Ref) – The SMEM reference to copy from.
dst (_Ref) – The GMEM reference to copy to.
predicate (jax.Array | None) – A boolean indicating whether the copy should be performed. If None, the copy is always performed.
commit_group (bool) – If True, this copy and any previously uncommitted copies are committed to a group that can be awaited jointly via jax.experimental.mosaic.gpu.wait_smem_to_gmem().
reduction_op (mgpu.ReductionOp | None) – If set, performs the specified reduction when storing to GMEM instead of overwriting the destination. For example, "add" is conceptually equivalent to dst += src: the SMEM contents are accumulated into the GMEM destination.
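You can summarize the combined effect of predicate and reduction_op on the destination with a small host-side model. The sketch below is a hypothetical NumPy stand-in. The function reference_copy_smem_to_gmem is not part of JAX. It runs synchronously on ordinary arrays and only describes what dst holds once the copy has completed. It models only the "add" reduction, and a plain Python bool stands in for the jax.Array predicate.

```python
import numpy as np


def reference_copy_smem_to_gmem(src, dst, predicate=None, reduction_op=None):
    """Hypothetical NumPy model of the data effect of one completed copy.

    This is NOT the Pallas API: it runs synchronously on host arrays and only
    models what ``dst`` holds after the asynchronous copy has finished.
    """
    # A predicate of None means the copy is always performed.
    if predicate is not None and not bool(predicate):
        return  # predicate is false: dst is left untouched
    if reduction_op is None:
        dst[...] = src  # plain store: overwrite GMEM with the SMEM contents
    elif reduction_op == "add":
        dst += src  # reduction store: accumulate SMEM contents into GMEM
    else:
        # Only "add" is modeled in this sketch; other ReductionOp values are not.
        raise NotImplementedError(f"reduction_op={reduction_op!r} is not modeled")


# Example: accumulate a tile into GMEM instead of overwriting it.
gmem = np.ones(4, dtype=np.float32)
smem = np.arange(4, dtype=np.float32)
reference_copy_smem_to_gmem(smem, gmem, reduction_op="add")
print(gmem)  # [1. 2. 3. 4.]
```

An additive reduction store lets several producers accumulate partial results into the same GMEM buffer without first reading it back into SMEM.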
- Return type:
None
See also
jax.experimental.mosaic.gpu.wait_smem_to_gmem()
jax.experimental.mosaic.gpu.commit_smem()
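Because the copy is asynchronous, commit_group controls how issued copies are batched for waiting. The toy class below models that bookkeeping on the host. HypotheticalBulkCopyQueue is an illustration, not a JAX API. Its wait(n) method assumes the PTX cp.async.bulk.wait_group semantics, in which a wait with n returns once at most the n most recently committed groups remain pending. Whether wait_smem_to_gmem() behaves exactly this way is an assumption of the sketch. The class does not model commit_smem().

```python
from collections import deque


class HypotheticalBulkCopyQueue:
    """Toy host-side model of grouped asynchronous SMEM->GMEM copies.

    Assumption: wait(n) returns once at most the n most recently committed
    groups are still pending, mirroring PTX cp.async.bulk.wait_group.
    """

    def __init__(self):
        self.uncommitted = []  # copies issued but not yet committed to a group
        self.groups = deque()  # committed groups, oldest first
        self.completed = []  # copies known to have finished

    def copy(self, name, commit_group=True):
        self.uncommitted.append(name)
        if commit_group:
            # This copy and all earlier uncommitted copies form one group.
            self.groups.append(self.uncommitted)
            self.uncommitted = []

    def wait(self, n=0):
        # Retire the oldest groups until at most n groups remain in flight.
        # Uncommitted copies belong to no group and are not awaited.
        while len(self.groups) > n:
            self.completed.extend(self.groups.popleft())


q = HypotheticalBulkCopyQueue()
q.copy("tile0", commit_group=False)
q.copy("tile1")  # commits tile0 and tile1 as one group
q.copy("tile2")  # commits tile2 as its own group
q.wait(1)  # the newest group may stay pending
print(q.completed)  # ['tile0', 'tile1']
```

Issuing several copies with commit_group=False and committing only on the last one lets a single wait cover the whole batch.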