jax.nn.get_scaled_dot_general_config
- jax.nn.get_scaled_dot_general_config(mode, global_scale=None)
Get quantization configs for scaled_dot_general.
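Creates a quantization config for jax.nn.scaled_dot_general(). The mode argument selects the quantization scheme, and global_scale optionally supplies a global scaling factor (default: None).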
See also
jax.nn.scaled_dot_general(): Scaled dot general function.