jax.experimental.pallas.mosaic_gpu.SwizzleTransform

class jax.experimental.pallas.mosaic_gpu.SwizzleTransform(swizzle: 'int')
Parameters:

swizzle (int)
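This entry does not say what the swizzle value controls. The sketch below is a rough, hypothetical illustration only. It assumes swizzle is a width in bytes (32, 64 or 128) and that the resulting layout follows the XOR pattern of NVIDIA's TMA swizzle modes as modeled by CUTLASS's Swizzle<B, 4, 3>. Under that pattern, B bits of the byte offset starting at bit 7 are XORed into the 16-byte chunk index, which starts at bit 4. The helper swizzle_offset is not part of JAX, and the pattern is not taken from the Mosaic GPU source.

```python
# Hypothetical helper, not part of jax.experimental.pallas.mosaic_gpu.
# Assumption: CUTLASS-style Swizzle<B, 4, 3> XOR pattern, where
# B = log2(swizzle // 16) for swizzle widths of 32, 64 and 128 bytes.
_CHUNK_BITS = {32: 1, 64: 2, 128: 3}


def swizzle_offset(byte_offset: int, swizzle: int) -> int:
    """Return the swizzled byte offset for a linear byte offset."""
    if swizzle not in _CHUNK_BITS:
        raise ValueError(f"this sketch only models 32, 64 or 128, got {swizzle}")
    bits = _CHUNK_BITS[swizzle]
    mask = (1 << bits) - 1
    # Take `bits` bits of the offset starting at bit 7 (one step per 128 bytes)...
    selector = (byte_offset >> 7) & mask
    # ...and XOR them into the 16-byte chunk index (bits starting at bit 4).
    # Bits 7 and up are left unchanged, so applying this twice is the identity.
    return byte_offset ^ (selector << 4)


if __name__ == "__main__":
    # With a 128-byte swizzle, the first 16-byte chunk of each of the eight
    # 128-byte rows in a 1024-byte block lands in a different chunk slot.
    print([swizzle_offset(row * 128, 128) // 16 % 8 for row in range(8)])
    # prints [0, 1, 2, 3, 4, 5, 6, 7]
```

Because the mapping only permutes 16-byte chunks within each row, it is a bijection on each block; spreading a column of chunks across different slots is the usual motivation for swizzled shared-memory layouts, since it tends to reduce bank conflicts.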

__init__(swizzle)
Parameters:

swizzle (int)

Return type:

None

Methods

__init__(swizzle)

batch(leading_rank)

Returns a transform that accepts a ref with the extra leading_rank dims.

to_gpu_transform()

undo(ref)

undo_to_gpu_transform()

Attributes

swizzle