jax.flatten_util.ravel_pytree
- jax.flatten_util.ravel_pytree(pytree)
Ravel (flatten) a pytree of arrays down to a 1D array.
- Parameters:
pytree – a pytree of arrays and scalars to ravel.
- Returns:
A pair. The first element is a 1D array holding the flattened and concatenated leaf values, with a dtype determined by promoting the dtypes of the leaf values. The second element is a callable that unflattens a 1D vector of the same length back into a pytree with the same structure as the input pytree. If the input pytree is empty (i.e. has no leaves), then by convention the first component of the output is an empty 1D array of dtype float32.
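The round trip described above can be sketched as follows. This is a minimal illustration, not taken from the JAX docs: the names params, w and b are hypothetical. Note that JAX flattens dict leaves in sorted-key order, so "b" precedes "w" in the raveled vector.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree: a dict holding a 2x2 matrix and a length-3 vector.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(3)}

# Ravel: 4 + 3 leaf elements concatenated into a single 1D array.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (7,)

# unravel maps any 1D vector of the same length back to the original structure.
restored = unravel(flat * 2.0)
print(restored["w"].shape, restored["b"].shape)  # (2, 2) (3,)

# An empty pytree ravels to an empty float32 array by convention.
empty_flat, _ = ravel_pytree({})
print(empty_flat.shape, empty_flat.dtype)  # (0,) float32
```

A typical use is passing the flat vector to an optimizer or solver that expects a single 1D array, then calling unravel on its result.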
For details on dtype promotion, see https://docs.jax.dev/en/latest/type_promotion.html.
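The following sketch illustrates dtype promotion with mixed leaf dtypes. It assumes default JAX settings, and the tree shown is hypothetical. Under JAX's promotion rules, int32 and float32 promote to float32. In the JAX versions we are aware of, the returned unravel callable casts each chunk back to its leaf's original dtype. Treat that round-trip casting as an observed behavior rather than a documented guarantee.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical tree with leaves of two different dtypes.
tree = (jnp.arange(3, dtype=jnp.int32), jnp.ones(2, dtype=jnp.float32))

# The raveled vector uses the promoted dtype of all leaves.
flat, unravel = ravel_pytree(tree)
print(flat.dtype)  # float32

# Unraveling restores the original per-leaf dtypes (observed behavior).
ints, floats = unravel(flat)
print(ints.dtype, floats.dtype)  # int32 float32
```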