jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef
- class jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef(shape: 'tuple[int, int]', dtype: 'jnp.dtype' = <class 'jax.numpy.float32'>, _init: 'Any' = <jax._src.state.types.Uninitialized object>)
- __init__(shape, dtype=<class 'jax.numpy.float32'>, _init=<jax._src.state.types.Uninitialized object>)
Methods
__init__(shape[, dtype, _init])
get_ref_aval()
init(array)

Attributes
shape