jax.experimental.pallas.mosaic_gpu.TilingTransform
- class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)
Represents a tiling transformation for memory refs.
A tiling of (X, Y) applied to an array of shape (M, N) produces a transformed shape of (M // X, N // Y, X, Y). For example, a (256, 256) block tiled with a tiling of (64, 32) becomes (4, 8, 64, 32).
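To make the shape arithmetic above concrete, here is a minimal sketch that reproduces the (M, N) to (M // X, N // Y, X, Y) mapping on an ordinary `jax.numpy` array with a reshape and a transpose. It does not call `TilingTransform` itself. The helpers `tile_2d` and `untile_2d` are hypothetical, and the element mapping (tile `(i, j)`, offset `(a, b)` corresponds to `x[i * X + a, j * Y + b]`) assumes the usual row-major tiled layout rather than anything stated on this page.

```python
import jax.numpy as jnp


def tile_2d(x, tiling):
    """Hypothetical helper: lay out an (M, N) array as (M // X, N // Y, X, Y)."""
    m, n = x.shape
    tx, ty = tiling
    # Each dimension must divide evenly by its tile size.
    assert m % tx == 0 and n % ty == 0, "shape must be divisible by tiling"
    # Split each axis into (number of tiles, tile size): (M//X, X, N//Y, Y).
    split = x.reshape(m // tx, tx, n // ty, ty)
    # Move both tile-index axes to the front: (M//X, N//Y, X, Y).
    return split.transpose(0, 2, 1, 3)


def untile_2d(t):
    """Hypothetical inverse: recover the original (M, N) array."""
    gm, gn, tx, ty = t.shape
    # Undo the transpose, then merge (tiles, tile size) back into one axis.
    return t.transpose(0, 2, 1, 3).reshape(gm * tx, gn * ty)


x = jnp.arange(256 * 256, dtype=jnp.float32).reshape(256, 256)
tiled = tile_2d(x, (64, 32))
print(tiled.shape)  # (4, 8, 64, 32)
# Under the assumed layout, tiled[i, j, a, b] is x[i * 64 + a, j * 32 + b].
assert tiled[1, 2, 3, 4] == x[1 * 64 + 3, 2 * 32 + 4]
assert (untile_2d(tiled) == x).all()
```

The round trip through `untile_2d` only checks that the reshaping is lossless; how Mosaic GPU actually places tiles in shared memory is not described here.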
Methods
__init__(tiling)
batch(leading_rank): Returns a transform that accepts a ref with leading_rank extra leading dims.
to_gpu_transform()
undo(ref)
Attributes
tiling
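The `batch(leading_rank)` description suggests that extra leading dimensions can sit in front of the tiled ones. The shape-only sketch below assumes, without confirmation from this page, that the tiling applies to the trailing `len(tiling)` dimensions and that any leading dimensions pass through unchanged. `tiled_shape` is a hypothetical helper for reasoning about shapes, not part of the `mosaic_gpu` API.

```python
def tiled_shape(shape, tiling):
    """Hypothetical helper: shape after tiling the trailing len(tiling) dims.

    Assumption: leading dims (such as those a batched transform accepts)
    are passed through unchanged.
    """
    rank = len(tiling)
    if rank > len(shape):
        raise ValueError(f"tiling {tiling} has more dims than shape {shape}")
    split = len(shape) - rank
    leading, trailing = tuple(shape[:split]), tuple(shape[split:])
    for size, tile in zip(trailing, tiling):
        if size % tile:
            raise ValueError(f"dim {size} is not divisible by tile size {tile}")
    # Untouched leading dims, then tile counts, then the tile shape itself.
    return leading + tuple(s // t for s, t in zip(trailing, tiling)) + tuple(tiling)


print(tiled_shape((256, 256), (64, 32)))     # (4, 8, 64, 32)
print(tiled_shape((2, 256, 256), (64, 32)))  # (2, 4, 8, 64, 32)
```

Raising on non-divisible dimensions is a design choice of this sketch; check the library source for how the real transform handles such shapes.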