jax.experimental.pallas.mosaic_gpu.Barrier

class jax.experimental.pallas.mosaic_gpu.Barrier(num_arrivals, num_barriers=1, *, for_tensor_core=False)

Describes a barrier Ref: a group of num_barriers barriers, each of which records num_arrivals arrivals.

Parameters:
  • num_arrivals (int)

  • num_barriers (int)

  • for_tensor_core (bool)

num_arrivals

The number of arrivals that will be recorded by this barrier.

Type:

int

num_barriers

The number of barriers that will be created. Individual barriers can be accessed by indexing into the barrier Ref.

Type:

int

for_tensor_core

Whether this barrier is used for synchronizing with the tensor core. Set this to True when waiting on Blackwell (TC Gen 5) asynchronous matmul instructions.

Type:

bool

__init__(num_arrivals, num_barriers=1, *, for_tensor_core=False)

Parameters:
  • num_arrivals (int)

  • num_barriers (int)

  • for_tensor_core (bool)

Return type:

None

Methods

__init__(num_arrivals[, num_barriers, ...])

get_ref_aval()

Attributes

for_tensor_core

num_barriers

num_arrivals