jax.experimental.pallas.mosaic_gpu.Layout
- class jax.experimental.pallas.mosaic_gpu.Layout(value)
An enumeration.
- __init__()
Methods
to_mgpu(*args, **kwargs)
Attributes
WGMMA
[m, n] matrix, where m % 64 == 0 == n % 8.
WGMMA_ROW
[m] matrix, where m % 64 == 0.
WGMMA_COL
[n] matrix, where n % 8 == 0.
WGMMA_TRANSPOSED
WG_SPLAT
WG_STRIDED
TCGEN05
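The shape constraints listed for WGMMA, WGMMA_ROW and WGMMA_COL are plain divisibility rules, so they can be checked with integer arithmetic before choosing a layout. The sketch below is a minimal illustration under that reading. The helper names fits_wgmma, fits_wgmma_row and fits_wgmma_col are hypothetical and not part of JAX. The code does not import or call Pallas or Mosaic GPU, and it only restates the documented conditions, not any additional requirements the library may enforce.

```python
# Hypothetical helpers (not part of the JAX API). They restate the documented
# shape constraints for the WGMMA, WGMMA_ROW and WGMMA_COL layouts.

def fits_wgmma(shape):
    """[m, n] matrix, where m % 64 == 0 and n % 8 == 0."""
    if len(shape) != 2:
        return False
    m, n = shape
    return m % 64 == 0 and n % 8 == 0


def fits_wgmma_row(shape):
    """[m] vector, where m % 64 == 0."""
    return len(shape) == 1 and shape[0] % 64 == 0


def fits_wgmma_col(shape):
    """[n] vector, where n % 8 == 0."""
    return len(shape) == 1 and shape[0] % 8 == 0


print(fits_wgmma((128, 64)))   # True
print(fits_wgmma((96, 64)))    # False: 96 % 64 == 32
print(fits_wgmma_row((128,)))  # True
print(fits_wgmma_col((12,)))   # False: 12 % 8 == 4
```

The [m] and [n] extents in WGMMA_ROW and WGMMA_COL use the same divisors as the row and column dimensions of a WGMMA [m, n] matrix, which is why the helpers share the same checks.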