jax.experimental.pallas.mosaic_gpu.TilingTransform

class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)

Represents a tiling transformation for memory refs.

A tiling of (X, Y) on an array of shape (M, N) produces a transformed shape of (M // X, N // Y, X, Y). For example, a (256, 256) block tiled with (64, 32) becomes (4, 8, 64, 32).
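The shape arithmetic above can be checked with a small pure-Python sketch. The helpers `tiled_shape_2d` and `source_index` are hypothetical illustrations, not part of the Pallas API. `source_index` assumes the usual tiling convention, in which element `[i, j, a, b]` of the tiled view corresponds to element `[i * X + a, j * Y + b]` of the original array. The sketch also assumes that M and N are divisible by X and Y.

```python
# Hypothetical helpers illustrating the (M, N) -> (M // X, N // Y, X, Y) tiling.
# Not part of jax.experimental.pallas.mosaic_gpu.


def tiled_shape_2d(shape, tiling):
    """Return the tiled shape for a 2D shape and a 2D tiling."""
    (m, n), (x, y) = shape, tiling
    # Assumption: each dimension is an exact multiple of its tile size.
    if m % x or n % y:
        raise ValueError(f"shape {shape} is not divisible by tiling {tiling}")
    return (m // x, n // y, x, y)


def source_index(tile_index, in_tile_index, tiling):
    """Map (tile index, offset within tile) back to an index into the untiled array."""
    return tuple(i * t + a for i, a, t in zip(tile_index, in_tile_index, tiling))


print(tiled_shape_2d((256, 256), (64, 32)))       # (4, 8, 64, 32)
print(source_index((1, 2), (3, 4), (64, 32)))     # (67, 68)
```

Element `[1, 2, 3, 4]` of the tiled view therefore maps to row `1 * 64 + 3 = 67` and column `2 * 32 + 4 = 68` of the original (256, 256) block.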

Parameters:

tiling (tuple[int, ...])

__init__(tiling)
Parameters:

tiling (tuple[int, ...])

Return type:

None

Methods

__init__(tiling)

batch(leading_rank)

Returns a transform that accepts a ref with the extra leading_rank dims.

to_gpu_transform()

undo(ref)
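batch(leading_rank) returns a transform that accepts a ref with leading_rank extra leading dimensions. The sketch below shows one plausible reading of the effect on shapes. It assumes the leading dimensions pass through unchanged and that the tiling still applies to the remaining (M, N) dimensions. The helper `batched_tiled_shape` is hypothetical and does not reproduce the library's implementation.

```python
# Hypothetical shape-level sketch of a batched tiling transform.
# Assumption: the first `leading_rank` dims are batch dims left untouched,
# and the tiling applies to the remaining dims as in the unbatched case.


def batched_tiled_shape(shape, tiling, leading_rank):
    if len(shape) != leading_rank + len(tiling):
        raise ValueError("shape rank must equal leading_rank + len(tiling)")
    lead = tuple(shape[:leading_rank])  # batch dims, passed through as-is
    rest = tuple(shape[leading_rank:])  # dims that get tiled
    for d, t in zip(rest, tiling):
        if d % t:
            raise ValueError(f"dimension {d} is not divisible by tile size {t}")
    return lead + tuple(d // t for d, t in zip(rest, tiling)) + tuple(tiling)


print(batched_tiled_shape((3, 256, 256), (64, 32), leading_rank=1))  # (3, 4, 8, 64, 32)
```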

Attributes

tiling