jax.nn.get_scaled_dot_general_config

jax.nn.get_scaled_dot_general_config(mode, global_scale=None)

Get quantization configs for scaled_dot_general.

Creates quantization configs for jax.nn.scaled_dot_general(), using the low-precision format selected by mode.

See also: jax.nn.scaled_dot_general()

Parameters:
  • mode (Literal['nvfp4', 'mxfp8']) – the quantization format, either 'nvfp4' or 'mxfp8'.

  • global_scale (Array | None) – optional global scale; defaults to None.