jax.experimental.pallas.mosaic_gpu.Layout

class jax.experimental.pallas.mosaic_gpu.Layout(value)

An enumeration of the register array layouts available in Pallas Mosaic GPU kernels.

__init__()

Methods

to_mgpu(*args, **kwargs)

Convert this layout to the corresponding Mosaic GPU layout.

Attributes

WGMMA

[m, n] matrix, where m % 64 == 0 and n % 8 == 0.
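The divisibility rule above can be checked on a shape before choosing this layout. The following is a minimal sketch in plain Python; `fits_wgmma` is a hypothetical helper, not part of the Pallas API, and it only restates the documented constraint without calling into JAX.

```python
def fits_wgmma(shape: tuple[int, ...]) -> bool:
    """Hypothetical check for Layout.WGMMA: an [m, n] shape with
    m % 64 == 0 and n % 8 == 0. Does not call into JAX or Pallas."""
    if len(shape) != 2:
        return False  # WGMMA describes a 2D [m, n] matrix
    m, n = shape
    return m % 64 == 0 and n % 8 == 0


# Usage
print(fits_wgmma((128, 64)))  # True
print(fits_wgmma((64, 12)))   # False: 12 is not a multiple of 8
print(fits_wgmma((32, 8)))    # False: 32 is not a multiple of 64
```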

WGMMA_ROW

[m] vector, where m % 64 == 0.

WGMMA_COL

[n] vector, where n % 8 == 0.
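The two 1D constraints line up with the dimensions of the WGMMA entry: the length of a WGMMA_ROW array follows the same rule as m, and the length of a WGMMA_COL array follows the same rule as n. The sketch below checks this in plain Python; `fits_wgmma_row` and `fits_wgmma_col` are hypothetical helpers that only restate the documented constraints and do not call into JAX.

```python
def fits_wgmma_row(shape: tuple[int, ...]) -> bool:
    """Hypothetical check for Layout.WGMMA_ROW: a 1D [m] shape with m % 64 == 0."""
    return len(shape) == 1 and shape[0] % 64 == 0


def fits_wgmma_col(shape: tuple[int, ...]) -> bool:
    """Hypothetical check for Layout.WGMMA_COL: a 1D [n] shape with n % 8 == 0."""
    return len(shape) == 1 and shape[0] % 8 == 0


# For an [m, n] shape that satisfies the WGMMA rule, the row length and the
# column length satisfy the WGMMA_ROW and WGMMA_COL rules respectively.
m, n = 128, 24
print(fits_wgmma_row((m,)))  # True
print(fits_wgmma_col((n,)))  # True
print(fits_wgmma_row((n,)))  # False: 24 is not a multiple of 64
```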

WGMMA_TRANSPOSED

WG_SPLAT

WG_STRIDED

TCGEN05