jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef

class jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef(shape: 'tuple[int, int]', dtype: 'jnp.dtype' = <class 'jax.numpy.float32'>, _init: 'Any' = <jax._src.state.types.Uninitialized object>)
Parameters:
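  • shape (tuple[int, int]) – Two-dimensional shape of the accumulator.

  • dtype (jnp.dtype) – Element type of the accumulator. Defaults to float32.

  • _init (Any) – Private initial value. Leave it unset and use init() to create a preinitialized accumulator.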

__init__(shape, dtype=<class 'jax.numpy.float32'>, _init=<jax._src.state.types.Uninitialized object>)
Parameters:
  • shape (tuple[int, int])

  • dtype (jnp.dtype)

  • _init (Any)

Return type:

None

Methods

__init__(shape[, dtype, _init])

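get_ref_aval() – Return the abstract reference type used to allocate the accumulator.

init(array) – Create an accumulator whose shape, dtype, and initial value are taken from array.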

Attributes

shape