jax.export.register_namedtuple_serialization

jax.export.register_namedtuple_serialization(nodetype, *, serialized_name)

Registers a namedtuple for serialization and deserialization.

JAX has native PyTree support for collections.namedtuple and does not require a call to jax.tree_util.register_pytree_node for such types. However, if you want to serialize functions that have inputs or outputs of a namedtuple type, you must register that type for serialization.

Parameters:
  • nodetype (type[T]) – the namedtuple type whose PyTree nodes we want to serialize. It is an error to attempt to register multiple serializations for a nodetype. On deserialization, this type must have the same set of keys (field names) that were present during serialization.

  • serialized_name (str) – a string that will be present in the serialization and will be used to look up the registration during deserialization. It is an error to attempt to register multiple serializations for a serialized_name.

Returns:

the same type passed as nodetype, so that this function can be used as a class decorator.

Return type:

type[T]