jax.experimental.pallas.mosaic_gpu.TransposeTransform

class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)

Transposes a tiled memref by permuting its dimensions according to permutation.
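The following pure-Python sketch illustrates what a permutation does to a ref's shape. It assumes the same convention as the axes argument of numpy.transpose: output dimension i takes the size of input dimension permutation[i]. The helpers is_permutation and transposed_shape are hypothetical illustrations, not part of the jax.experimental.pallas.mosaic_gpu API, and they only model shapes, not the memory layout Mosaic GPU actually produces.

```python
def is_permutation(permutation: tuple[int, ...]) -> bool:
    # A valid permutation contains each index 0..n-1 exactly once.
    return sorted(permutation) == list(range(len(permutation)))


def transposed_shape(shape: tuple[int, ...], permutation: tuple[int, ...]) -> tuple[int, ...]:
    # Assumed convention (as in numpy.transpose): output dim i has the
    # size of input dim permutation[i].
    if len(shape) != len(permutation) or not is_permutation(permutation):
        raise ValueError(f"{permutation} is not a permutation of the dims of {shape}")
    return tuple(shape[p] for p in permutation)


# An arbitrary 4-D shape: swap the first two dims and the last two dims.
print(transposed_shape((2, 4, 64, 32), (1, 0, 3, 2)))  # (4, 2, 32, 64)
```

Validating the permutation up front makes a malformed argument such as (0, 0) fail immediately instead of silently producing a wrong shape.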

Parameters:

permutation (tuple[int, ...])

__init__(permutation)
Parameters:

permutation (tuple[int, ...])

Return type:

None

Methods

__init__(permutation)

batch(leading_rank)

Returns a transform that accepts a ref with leading_rank extra leading dimensions.
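A minimal sketch of how a permutation could be extended for batching, assuming the leading_rank new dimensions are left in place and every original dimension index is shifted past them. The function batch_permutation is hypothetical and only models the permutation tuple, not the transform object that batch actually returns.

```python
def batch_permutation(permutation: tuple[int, ...], leading_rank: int) -> tuple[int, ...]:
    # Assumption: the new leading dims map to themselves (identity), and the
    # original dims keep their relative order, offset by leading_rank.
    return (*range(leading_rank), *(d + leading_rank for d in permutation))


# A 2-D swap (1, 0) applied under two extra leading (batch) dims.
print(batch_permutation((1, 0), 2))  # (0, 1, 3, 2)
```

Keeping the leading dimensions fixed means the batched transform acts on each trailing sub-array exactly as the original transform would.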

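to_gpu_transform()

Returns the equivalent lower-level Mosaic GPU memref transform.

undo(ref)

Returns ref with this transposition undone, i.e. viewed through the inverse permutation.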
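Undoing a transposition amounts to applying the inverse permutation. The sketch below computes that inverse and checks the round trip on a shape, again assuming the numpy.transpose convention. The helpers inverse_permutation and permute are hypothetical illustrations, not functions exported by jax.experimental.pallas.mosaic_gpu.

```python
def inverse_permutation(permutation: tuple[int, ...]) -> tuple[int, ...]:
    # Input dim p lands at output position i, so the inverse must read
    # position i back into slot p.
    inverse = [0] * len(permutation)
    for i, p in enumerate(permutation):
        inverse[p] = i
    return tuple(inverse)


def permute(shape: tuple[int, ...], permutation: tuple[int, ...]) -> tuple[int, ...]:
    # Assumed convention (as in numpy.transpose): output dim i takes the
    # size of input dim permutation[i].
    return tuple(shape[p] for p in permutation)


perm = (2, 0, 1)
shape = (8, 16, 32)
transposed = permute(shape, perm)                          # (32, 8, 16)
restored = permute(transposed, inverse_permutation(perm))  # (8, 16, 32)
print(inverse_permutation(perm), restored)
```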

Attributes

permutation