jax.nn.scaled_dot_general

jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, configs=None, implementation=None)

Scaled dot general operation.

Performs a generalized dot product with block-scaled quantization on the lhs and rhs inputs. This operation extends lax.dot_general to support user-defined scaling configurations.

Conceptually, the operation performs:

a, a_scales = quantize(lhs, configs[0])
b, b_scales = quantize(rhs, configs[1])
c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)

Parameters:
  • lhs (ArrayLike) – Left-hand-side input array.

  • rhs (ArrayLike) – Right-hand-side input array.

  • dimension_numbers (DotDimensionNumbers) – A tuple of two tuples specifying the contraction and batch dimensions: ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)).

  • preferred_element_type (DTypeLike, optional) – Output data type of the dot product. Defaults to jnp.float32. Other valid types include jnp.bfloat16 and jnp.float16.

  • configs (list of BlockScaleConfig, optional) – Scaling configurations for lhs, rhs, and gradients. Users can obtain valid configurations via jax.nn.get_scaled_dot_general_config. Currently, nvfp4 and mxfp8 are supported. If None, falls back to lax.dot_general.

  • implementation (Literal['cudnn'] | None, optional) – (Deprecated) Backend selector; now ignored, as the backend is chosen automatically. Scheduled for removal in a future release.

Returns:

The resulting tensor, with batch dimensions first, followed by the non-contracting/non-batch dimensions of lhs, and then those of rhs (see the shape check below).

Return type:

Array
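
As a quick shape check, here is a minimal sketch: with configs=None (the default), the call falls back to lax.dot_general, so it runs without any quantization backend. Contracting over the last dimension and batching over the first yields a result with the batch dimension first, followed by the remaining lhs and rhs dimensions.

>>> import jax
>>> import jax.numpy as jnp
>>> lhs = jnp.ones((3, 128, 64))
>>> rhs = jnp.ones((3, 128, 64))
>>> dn = (((2,), (2,)), ((0,), (0,)))  # contract dim 2 of both; batch over dim 0
>>> jax.nn.scaled_dot_general(lhs, rhs, dn).shape
(3, 128, 128)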

See also

  • jax.nn.scaled_matmul

Notes

  • Unlike nn.scaled_matmul, which assumes quantized low-precision inputs with explicit scaling factors, this operator takes high-precision inputs, applies quantization internally, and handles the backward pass.

Examples

Creating config for mxfp8:

>>> import jax
>>> import jax.numpy as jnp
>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3

Creating config for nvfp4:

>>> global_scale = jnp.array([0.5], jnp.float32)
>>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3

Using scaled_dot_general with the configs:

>>> import functools
>>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
>>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
>>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
>>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
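
Since scaled_dot_general handles the backward pass (see Notes), it can be differentiated like any other JAX operation. A minimal sketch continuing the example above, with a hypothetical sum-based scalar loss (it requires the same backend support as the forward call):

>>> def loss(l, r):
...     return jnp.sum(scaled_dot_general_fn(l, r, (((2,), (2,)), ((0,), (0,)))))
>>> lhs_grad, rhs_grad = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
>>> lhs_grad.shape
(3, 128, 64)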