jax.experimental.pallas.mosaic_gpu.set_max_registers

jax.experimental.pallas.mosaic_gpu.set_max_registers(n, *, action)

Sets the maximum number of registers owned by a warp.

Parameters:
  • n (int) – the new maximum register count, expressed per thread.

  • action (Literal['increase', 'decrease']) – 'increase' raises the current limit to n; 'decrease' lowers it to n.
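Picking n for each warpgroup amounts to bookkeeping against the register file. The sketch below is a hypothetical pure-Python helper (check_register_split is not part of JAX) that checks a proposed split before the counts are passed to set_max_registers inside a kernel. It assumes the constraints of the PTX setmaxnreg instruction (count a multiple of 8 between 24 and 256), 128 threads per warpgroup, a 65,536-register file per SM as on Hopper, and a single resident thread block. Adjust these assumptions for your target.

```python
# Hypothetical helper, not a JAX API. Assumptions: PTX setmaxnreg limits
# (24..256, multiple of 8), 128 threads per warpgroup, a 65,536-register
# file per SM (Hopper), and one resident thread block per SM.

REGISTER_FILE_SIZE = 65_536  # 32-bit registers per SM (assumed: Hopper)
THREADS_PER_WARPGROUP = 128


def check_register_split(counts: list[int]) -> int:
    """Validate per-thread register counts, one entry per warpgroup.

    Returns the total number of registers the block would hold.
    Raises ValueError if a count violates the assumed setmaxnreg limits
    or if the total exceeds the assumed register file size.
    """
    for n in counts:
        if not 24 <= n <= 256 or n % 8 != 0:
            raise ValueError(f"{n} is not a multiple of 8 in [24, 256]")
    total = sum(n * THREADS_PER_WARPGROUP for n in counts)
    if total > REGISTER_FILE_SIZE:
        raise ValueError(
            f"split needs {total} registers, only {REGISTER_FILE_SIZE} available"
        )
    return total


# One producer warpgroup lowered to 40, two consumer warpgroups raised to 232.
print(check_register_split([40, 232, 232]))  # 64512
```

The 40/232/232 split reflects a pattern often seen in Hopper warp-specialized kernels: a producer warpgroup that mostly issues memory copies gives up registers with action='decrease', and the compute warpgroups claim them with action='increase'.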