jax.experimental.pallas.mosaic_gpu.TransposeTransform

class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)

Transpose a tiled memref.

Parameters:
    permutation (tuple[int, ...])

__init__(permutation)

    Parameters:
        permutation (tuple[int, ...])
    Return type:
        None

Methods

__init__(permutation)
batch(leading_rank)    Returns a transform that accepts a ref with the extra leading_rank dims.
to_gpu_transform()
undo(ref)

Attributes

permutation
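This page does not spell out how `permutation` is interpreted or what `undo` and `batch` compute. The block below is a sketch of one plausible reading, offered only as an illustration. It assumes that `permutation` follows the `numpy.transpose(axes=...)` convention, where result dimension `i` is source dimension `permutation[i]`. It further assumes that `undo` applies the inverse permutation and that `batch(leading_rank)` keeps the new leading dimensions in place and shifts the original permutation past them. `PermutationSketch` is a hypothetical helper, not part of JAX. It works on NumPy arrays, not on Mosaic GPU refs.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class PermutationSketch:
    """Hypothetical stand-in (not part of JAX) for a dimension permutation.

    Assumption: permutation[i] names the source dimension that becomes
    dimension i of the result, as in numpy.transpose(x, axes=permutation).
    """

    permutation: tuple[int, ...]

    def __post_init__(self):
        # Reject anything that is not a permutation of 0..rank-1.
        if sorted(self.permutation) != list(range(len(self.permutation))):
            raise ValueError(f"not a permutation: {self.permutation}")

    def apply(self, x: np.ndarray) -> np.ndarray:
        # Reorder the array's dimensions according to the permutation.
        return np.transpose(x, self.permutation)

    def undo(self, x: np.ndarray) -> np.ndarray:
        # Assumed meaning of "undo": apply the inverse permutation.
        # argsort of a permutation is its inverse.
        inverse = tuple(int(i) for i in np.argsort(self.permutation))
        return np.transpose(x, inverse)

    def batch(self, leading_rank: int) -> "PermutationSketch":
        # Assumed meaning of "batch": the extra leading dims stay in place,
        # and the original permutation is shifted past them.
        return PermutationSketch(
            tuple(range(leading_rank))
            + tuple(d + leading_rank for d in self.permutation)
        )


# A 4-D array standing in for a tiled (16, 64) buffer with (8, 16) tiles:
# (tile rows, tile cols, rows per tile, cols per tile).
x = np.arange(2 * 4 * 8 * 16).reshape(2, 4, 8, 16)
t = PermutationSketch((1, 0, 3, 2))
y = t.apply(x)              # shape (4, 2, 16, 8)
restored = t.undo(y)        # same shape and contents as x
batched = t.batch(1)        # permutation (0, 2, 1, 4, 3)
```

Under these assumptions, `undo` is always the exact inverse of `apply`. Batching one leading dimension turns `(1, 0, 3, 2)` into `(0, 2, 1, 4, 3)`, so the batch axis is never moved.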