jax.nn.scaled_matmul#
- jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales, preferred_element_type=<class 'jax.numpy.float32'>)[source]#
Scaled matrix multiplication function.
Performs a block-scaled matmul of lhs (a) and rhs (b) using lhs_scales (a_scales) and rhs_scales (b_scales). The last dimension is the contracting dimension, and the block size is inferred from the operand and scale shapes.
Mathematically, this operation is equivalent to:
a_block_size = a.shape[-1] // a_scales.shape[-1]
b_block_size = b.shape[-1] // b_scales.shape[-1]
a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
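The equivalence above can be checked with a small self-contained sketch (plain NumPy stands in for jax.numpy here; the names mirror the formula above, and the tiny shapes are illustrative only):

```python
import numpy as np

def scaled_matmul_reference(a, b, a_scales, b_scales):
    """Reference block-scaled matmul, mirroring the formula above.

    a: (B, M, K), b: (B, N, K). Each scale applies to one block of the
    contracting (last) dimension; the block size is inferred from shapes.
    """
    a_block_size = a.shape[-1] // a_scales.shape[-1]
    b_block_size = b.shape[-1] // b_scales.shape[-1]
    a_scaled = a * np.repeat(a_scales, a_block_size, axis=-1)
    b_scaled = b * np.repeat(b_scales, b_block_size, axis=-1)
    return np.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

# Tiny check: one batch, blocks of size 2 along K=4.
a = np.arange(8, dtype=np.float32).reshape(1, 2, 4)          # (B=1, M=2, K=4)
b = np.ones((1, 3, 4), dtype=np.float32)                     # (B=1, N=3, K=4)
a_scales = np.array([[[1.0, 0.5], [2.0, 1.0]]], np.float32)  # (1, 2, K_a=2)
b_scales = np.ones((1, 3, 2), dtype=np.float32)              # (1, 3, K_b=2)
out = scaled_matmul_reference(a, b, a_scales, b_scales)
print(out.shape)  # (1, 2, 3)
```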
- Parameters:
  - lhs – Operand a of shape (B, M, K).
  - rhs – Operand b of shape (B, N, K).
  - lhs_scales – Block scales a_scales for lhs, shape (B, M, K_a).
  - rhs_scales – Block scales b_scales for rhs, shape (B, N, K_b).
  - preferred_element_type – Output dtype; defaults to jnp.float32.
- Returns:
Array of shape (B, M, N).
- Return type:
Array
Notes
User-defined precision for the compute data type is not currently supported; it is fixed to jnp.float32.
Block size is inferred as K // K_a for a and K // K_b for b.
To use cuDNN on NVIDIA Blackwell GPUs, inputs must match one of the following configurations:

# mxfp8
a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
a_scales, b_scales: jnp.float8_e8m0fnu
block_size: 32

# nvfp4
a, b: jnp.float4_e2m1fn
a_scales, b_scales: jnp.float8_e4m3fn
block_size: 16
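For intuition about the e8m0 scales in the mxfp8 configuration: they are power-of-two factors chosen per block of 32 contracting-dimension elements so the scaled values fit the float8_e4m3fn range (max finite value 448.0). A hedged NumPy sketch, illustrative only and not JAX's actual quantization code:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def mx_block_scales(x, block_size=32):
    """Illustrative per-block power-of-two scales (e8m0-style).

    x: (..., K) with K divisible by block_size. Returns (scales, x_scaled)
    such that x == x_scaled * repeat(scales, block_size, axis=-1), with
    x_scaled kept within the float8_e4m3fn range.
    """
    *lead, K = x.shape
    assert K % block_size == 0
    blocks = x.reshape(*lead, K // block_size, block_size)
    amax = np.abs(blocks).max(axis=-1)  # (..., K // block_size)
    # Power-of-two scale so that amax / scale <= FP8_E4M3_MAX.
    scales = 2.0 ** np.ceil(np.log2(np.maximum(amax, 1e-30) / FP8_E4M3_MAX))
    x_scaled = (blocks / scales[..., None]).reshape(*lead, K)
    return scales, x_scaled

x = np.random.default_rng(0).normal(size=(1, 4, 64)).astype(np.float32) * 1000
scales, x_scaled = mx_block_scales(x)
# Scaled values fit the fp8 range; power-of-two scales reconstruct exactly.
recon = x_scaled * np.repeat(scales, 32, axis=-1)
```

The power-of-two constraint mirrors e8m0 (an exponent-only format), which is why the reconstruction above is lossless before the final cast to fp8.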
Examples
Basic case:
>>> import jax.numpy as jnp
>>> from jax.nn import scaled_matmul
>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
>>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
>>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> scaled_matmul(a, b, a_scales, b_scales)
Array([[[8.]]], dtype=float32)
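The result can be verified by hand: each operand has a single scale covering its whole K block of size 3, so the plain dot product (1·4 + 2·5 + 3·6) = 32 is multiplied by 0.5 · 0.5 = 0.25. In plain NumPy:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
# One scale per full K block (block size 3), so each operand is simply
# multiplied elementwise by its scale before contracting.
out = np.dot(a * 0.5, b * 0.5)
print(out)  # 8.0
```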
Using fused cuDNN call on Blackwell GPUs:
>>> import jax
>>> dtype = jnp.float8_e4m3fn
>>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype)
>>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype)
>>> a_scales = jnp.ones((3, 128, 2), dtype=jnp.float8_e8m0fnu)  # K_a = 64 // 32
>>> b_scales = jnp.ones((3, 128, 2), dtype=jnp.float8_e8m0fnu)  # K_b = 64 // 32
>>> scaled_matmul(a, b, a_scales, b_scales)