jax.copy_to_host_async
- jax.copy_to_host_async(x)
Tries to call a copy_to_host_async method on pytree leaves.

For each leaf, this function tries to call the leaf's copy_to_host_async method. If the leaf is not a JAX array, or if it does not have a copy_to_host_async method, the function does nothing to that leaf.
- Parameters:
x – a pytree, usually with at least some JAX array instances at its leaves.
- Returns:
A pytree with the same structure and values as the input, in which the copy to host has been started for every JAX array leaf.