jax.export.register_namedtuple_serialization
- jax.export.register_namedtuple_serialization(nodetype, *, serialized_name)
Registers a namedtuple for serialization and deserialization.
JAX has native PyTree support for collections.namedtuple and does not require a call to jax.tree_util.register_pytree_node. However, if you want to serialize functions that have inputs or outputs of a namedtuple type, you must register that type for serialization.
- Parameters:
nodetype (type[T]) – the type whose PyTree nodes we want to serialize. It is an error to attempt to register multiple serializations for a nodetype. On deserialization, this type must have the same set of keys that were present during serialization.
serialized_name (str) – a string that will be present in the serialization and will be used to look up the registration during deserialization. It is an error to attempt to register multiple serializations for a serialized_name.
- Returns:
the same type passed as nodetype, so that this function can be used as a class decorator.
- Return type:
type[T]