jax.experimental.pallas.mosaic_gpu.GPUBlockSpec

class jax.experimental.pallas.mosaic_gpu.GPUBlockSpec(block_shape: 'Sequence[BlockDim | int | None] | None' = None, index_map: 'Callable[..., Any] | None' = None, indexing_mode: 'Any | None' = None, pipeline_mode: 'Buffered | None' = None, transforms: 'Sequence[MemoryRefTransform]' = (), *, memory_space: 'Any | None' = None)
Parameters:
  • block_shape (Sequence[BlockDim | int | None] | None)

  • index_map (Callable[..., Any] | None)

  • indexing_mode (Any | None)

  • pipeline_mode (Buffered | None)

  • transforms (Sequence[MemoryRefTransform])

  • memory_space (Any | None)
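As a rough illustration of how block_shape and index_map interact, the sketch below computes which region of a full array a single grid step would see. It assumes Pallas's blocked indexing, in which index_map returns block indices rather than element offsets, and treats a None entry in block_shape as a squeezed size-1 dimension. block_slices is a hypothetical helper written for this illustration; it is not part of the JAX API.

```python
from __future__ import annotations

from typing import Callable, Sequence


def block_slices(
    block_shape: Sequence[int | None],
    index_map: Callable[..., tuple[int, ...]],
    *grid_indices: int,
) -> tuple[slice | int, ...]:
    """Hypothetical helper: the region of the full array one grid step sees.

    Assumes blocked indexing: index_map returns *block* indices, so the
    element offset along each dimension is block_index * block_size.
    A None entry is treated as a size-1 dimension that is squeezed away,
    so it yields a plain integer index instead of a slice.
    """
    block_indices = index_map(*grid_indices)
    if len(block_indices) != len(block_shape):
        raise ValueError("index_map must return one index per block dimension")
    result: list[slice | int] = []
    for size, idx in zip(block_shape, block_indices):
        if size is None:
            result.append(idx)  # squeezed dimension: an index, not a slice
        else:
            start = idx * size
            result.append(slice(start, start + size))
    return tuple(result)


# A (256, 512) array split into (128, 128) blocks over a (2, 4) grid,
# where grid step (i, j) reads block (i, j).
print(block_slices((128, 128), lambda i, j: (i, j), 1, 2))
# (slice(128, 256, None), slice(256, 384, None))
```

Here grid step (1, 2) maps to block (1, 2), which covers rows 128 to 255 and columns 256 to 383 of the full array.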

__init__(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, transforms=(), *, memory_space=None)
Parameters:
  • block_shape (Sequence[BlockDim | int | None] | None)

  • index_map (Callable[..., Any] | None)

  • indexing_mode (Any | None)

  • pipeline_mode (Buffered | None)

  • transforms (Sequence[MemoryRefTransform])

  • memory_space (Any | None)

Return type:

None

Methods

  • __init__([block_shape, index_map, ...])

  • replace(**changes): Return a new object replacing specified fields with new values.

  • to_block_mapping(origin, array_aval, *, ...)
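The description of replace matches the behavior of Python's dataclasses.replace: the original spec is left unchanged and a modified copy is returned. The sketch below assumes that behavior and uses a hypothetical frozen dataclass, SpecLike, as a stand-in so it runs without JAX or a GPU; with a real GPUBlockSpec the call would be spec.replace(...) in the same way.

```python
import dataclasses
from typing import Any, Callable, Optional, Sequence


@dataclasses.dataclass(frozen=True)
class SpecLike:
    """Hypothetical stand-in for a block spec; not the real GPUBlockSpec."""

    block_shape: Optional[Sequence[Optional[int]]] = None
    index_map: Optional[Callable[..., Any]] = None
    transforms: Sequence[Any] = ()

    def replace(self, **changes: Any) -> "SpecLike":
        # Build a new instance; fields not named in `changes` are copied over.
        return dataclasses.replace(self, **changes)


base = SpecLike(block_shape=(128, 128), index_map=lambda i, j: (i, j))
wider = base.replace(block_shape=(128, 256))
print(base.block_shape, wider.block_shape)  # (128, 128) (128, 256)
```

Because the stand-in is frozen, replace is the only way to obtain a variant, which keeps a spec safe to share between several pallas_call invocations.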

Attributes

  • block_shape

  • index_map

  • indexing_mode

  • memory_space

  • pipeline_mode

  • transforms
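The transforms attribute holds MemoryRefTransform objects that are applied to the memory reference for each block. As a hedged illustration of one kind of layout change, the NumPy sketch below splits the two trailing dimensions of a block into fixed-size tiles, assuming tiling semantics in which a (..., M, N) block becomes (..., M // tm, N // tn, tm, tn). tile_2d is a hypothetical helper for illustration only; it is not the Mosaic GPU implementation, which operates on memory references rather than NumPy arrays.

```python
import numpy as np


def tile_2d(x: np.ndarray, tiling: tuple) -> np.ndarray:
    """Hypothetical helper: split the two trailing dims into fixed-size tiles.

    A (..., M, N) array becomes (..., M // tm, N // tn, tm, tn), so that
    result[..., a, b, :, :] is the tile x[..., a*tm:(a+1)*tm, b*tn:(b+1)*tn].
    """
    tm, tn = tiling
    *lead, m, n = x.shape
    if m % tm or n % tn:
        raise ValueError(f"shape {(m, n)} is not divisible by tiling {tiling}")
    # Expose the tile-row and tile-column axes next to the within-tile axes.
    y = x.reshape(*lead, m // tm, tm, n // tn, tn)
    k = len(lead)
    # Move the two tile-index axes ahead of the two within-tile axes.
    order = list(range(k)) + [k, k + 2, k + 1, k + 3]
    return y.transpose(order)


x = np.arange(16 * 128).reshape(16, 128)
tiled = tile_2d(x, (8, 64))
print(tiled.shape)  # (2, 2, 8, 64)
```

In this sketch tiled[1, 0] holds rows 8 to 15 and columns 0 to 63 of x, so each tile is contiguous along the last two axes.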