jax.experimental.pallas.debug_print
- jax.experimental.pallas.debug_print(fmt, *args)
Prints values from inside a Pallas kernel.
- Parameters:
fmt (str) – A format string to be included in the output. The restrictions on the format string depend on the backend:
  - On GPU, when using Triton, fmt must not contain any placeholders ({...}), since it is always printed before any of the values.
  - On GPU, when using the experimental Mosaic GPU backend, fmt must contain a placeholder for each value to be printed. Format specs and conversions are not supported. All values must be scalars.
  - On TPU, if all inputs are scalars: if fmt contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after the format string.
  - On TPU, if the input is a single vector, the vector is printed after the format string. The format string must end with a single placeholder {}.
*args (jax.typing.ArrayLike) – The values to print.
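The following is a minimal sketch of calling debug_print from inside a kernel. It assumes a JAX release in which pl.pallas_call accepts interpret=True, so the kernel runs through the reference interpreter on a CPU-only machine. The kernel name add_one_kernel and the wrapper add_one are hypothetical. The value is a 32-bit integer scalar with one {} placeholder. This matches the TPU scalar rule and the Mosaic GPU rule above. Under Triton, where placeholders are not allowed, the call would instead be written as pl.debug_print("x[0] = ", x_ref[0]).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Hypothetical kernel: print the first element as a 32-bit integer
    # scalar through a single "{}" placeholder, then add 1 to every element.
    pl.debug_print("x[0] = {}", x_ref[0])
    o_ref[...] = x_ref[...] + 1


@jax.jit
def add_one(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Assumption: interpret=True runs the kernel in the reference
        # interpreter, so this sketch does not need a GPU or TPU.
        interpret=True,
    )(x)


x = jnp.arange(8, dtype=jnp.int32)
y = add_one(x)
jax.block_until_ready(y)
# Make sure the side-effecting print has completed before moving on.
jax.effects_barrier()
print(y)
```

On real hardware the same call is subject to the backend restrictions listed for fmt. For example, on TPU a non-integer scalar must be printed without placeholders.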