jax.experimental.pallas.debug_print

jax.experimental.pallas.debug_print(fmt, *args)

Prints values from inside a Pallas kernel.
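A minimal usage sketch follows. It assumes a JAX version whose Pallas interpret mode (interpret=True, which runs the kernel on CPU) supports debug_print. The names add_one_kernel and add_one are hypothetical. The call passes no {} placeholder, which is the form accepted by Triton and by TPU for scalar values. The Mosaic GPU backend would instead need one {} per value.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
  # Hypothetical kernel: print one scalar, then add 1 to the whole block.
  # No "{}" placeholder is used, so the value is printed after the string.
  pl.debug_print("x[0] =", x_ref[0])
  o_ref[...] = x_ref[...] + 1.0


def add_one(x: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      # Assumption: interpret=True runs the kernel on CPU so this sketch
      # works without an accelerator. Drop it to compile for GPU/TPU.
      interpret=True,
  )(x)


x = jnp.arange(8, dtype=jnp.float32)
y = add_one(x)
# Wait for pending side effects (such as the print) to complete.
jax.effects_barrier()
```

Printing is a side effect rather than a returned value, so the kernel's output is unchanged by the debug_print call.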

Parameters:
  • fmt (str) –

    A format string to be included in the output. The restrictions on the format string depend on the backend:

    • On GPU, when using Triton, fmt must not contain any placeholders ({...}), since it is always printed before any of the values.

    • On GPU, when using the experimental Mosaic GPU backend, fmt must contain a placeholder for each value to be printed. Format specs and conversions (for example, {:.2f} or {!r}) are not supported. All values must be scalars.

    • On TPU, if all inputs are scalars: if fmt contains placeholders, all values must be 32-bit integers; if there are no placeholders, the values are printed after the format string.

    • On TPU, if the input is a single vector, the vector is printed after the format string. The format string must end with a single placeholder {}.

  • *args (jax.typing.ArrayLike) – The values to print.
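The backend rules for fmt can be compared side by side in the sketch below. It makes the same assumption as before: interpret mode on CPU supports debug_print. The names sum_kernel and sum_with_print and the use_placeholders flag are hypothetical. Both printed values are int32 scalars, so the placeholder form satisfies two rules at once: the Mosaic GPU rule (one {} per scalar, no format specs) and the TPU rule (placeholders only with 32-bit integers). The no-placeholder form is the one Triton accepts. The TPU single-vector form appears only as a comment because it is not exercised here.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def sum_kernel(x_ref, o_ref, *, use_placeholders: bool):
  first = x_ref[0]              # int32 scalar
  total = jnp.sum(x_ref[...])   # int32 scalar
  if use_placeholders:
    # Mosaic GPU style (also valid on TPU for 32-bit integer scalars):
    # exactly one "{}" per value, with no specs such as "{:d}".
    pl.debug_print("first={} total={}", first, total)
  else:
    # Triton style: no placeholders; the values follow the string.
    pl.debug_print("first, total:", first, total)
  # On TPU, a single vector would use a trailing "{}", e.g.
  # pl.debug_print("x = {}", x_ref[...]). Not run in this CPU sketch.
  o_ref[0] = total


def sum_with_print(x: jax.Array, use_placeholders: bool) -> jax.Array:
  return pl.pallas_call(
      functools.partial(sum_kernel, use_placeholders=use_placeholders),
      out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
      interpret=True,  # assumption: CPU interpret mode, as above
  )(x)


x = jnp.arange(8, dtype=jnp.int32)
a = sum_with_print(x, use_placeholders=True)
b = sum_with_print(x, use_placeholders=False)
jax.effects_barrier()
```

In interpret mode, formatting may be handled on the host and be more permissive than a compiled backend. A format string that prints correctly there is not guaranteed to be accepted by Triton, Mosaic GPU, or TPU.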