JAX Memories and Host Offloading#
This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:
Activation offloading
Parameter offloading
By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you’ll need to understand JAX’s core mechanisms for data placement and movement.
Building Blocks for Offloading#
JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. In the following sections, you’ll explore:
How to specify data distribution with sharding
How to control memory placement between host and device
How to manage data movement in jitted functions
NamedSharding and Memory Kinds#
NamedSharding defines how data are distributed across devices. It includes:
Basic data distribution configuration
A memory_kind parameter for specifying the memory type (device or pinned_host); by default, memory_kind is set to device memory
A with_memory_kind method for creating a new sharding with a modified memory type
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np
# Create mesh
# 1x1 mesh represents a single device with two named dimensions (x and y)
mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))
# Device sharding - partitions data along x and y dimensions
s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device")
# Host sharding - same partitioning but in pinned host memory
s_host = s_dev.with_memory_kind('pinned_host')
print(s_dev) # Shows device memory sharding
print(s_host) # Shows pinned host memory sharding
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)
Data Placement with device_put#
jax.device_put() is a function that explicitly transfers arrays to a specified memory location according to a sharding specification.
# Create a 2x4 array
arr = jnp.arange(8.0).reshape(2, 4)
# Move arrays to different memory locations based on sharding objects
arr_host = jax.device_put(arr, s_host) # Places in pinned host memory
arr_dev = jax.device_put(arr, s_dev) # Places in device memory
# Verify memory locations
print(arr_host.sharding.memory_kind) # Output: pinned_host
print(arr_dev.sharding.memory_kind) # Output: device
pinned_host
device
Output Sharding Controls#
Shardings determine how data is split across devices. JAX provides out_shardings to control how output arrays are partitioned when leaving a jitted function.
Key Features:
Can differ from input sharding
Allows different memory kinds for outputs
Examples:
Device Output Sharding#
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D: \n", out_dev)
Result value of H2D:
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
Moving data from host to device memory when needed for computation is the essence of host offloading. In this example, jax.device_put() performs that transfer explicitly inside the jitted function.
# Instead of the lambda function, you can define add_func to explicitly
# move data to device before computation
def add_func(x):  # Move data to device and add one
    x = jax.device_put(x, s_dev)
    return x + 1
f = jax.jit(add_func, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D and add 1 in device memory: \n", out_dev)
Result value of H2D and add 1 in device memory:
[[1. 2. 3. 4.]
[5. 6. 7. 8.]]
Host Output Sharding#
f = jax.jit(lambda x: x, out_shardings=s_host)
out_host = f(arr_dev) # Input arrays in the device memory while output arrays in the host memory
print("Result value of D2H: \n", out_host)
Result value of D2H:
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
Activation Offloading#
The detailed coverage of activation offloading can be found in the Gradient checkpointing with jax.checkpoint (jax.remat) tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.
To implement activation offloading effectively, you need to understand checkpoint names and policies. Here’s how they work in a simple example:
Checkpoint Names#
The checkpoint_name() function allows you to label activations for memory management during computation. Here’s a simple example:
from jax.ad_checkpoint import checkpoint_name
def layer(x, w):
    w1, w2 = w
    x = checkpoint_name(x, "x")
    y = x @ w1
    return y @ w2, None
This example shows:
A simple neural network layer with two matrix multiplications
Labeling of input activation x with identifier "x"
Sequential operations:
First multiplication: x @ w1
Second multiplication: y @ w2
The checkpoint name helps the system decide whether to:
Keep the activation in device memory or
Offload it to host memory during computation
This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.
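A label by itself should not change what a function computes; it only gives a checkpoint policy something to refer to. The sketch below checks this on a toy function. The functions `labeled` and `unlabeled` and the label `"h"` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def labeled(x):
    # Tag the intermediate value so a remat policy could refer to it as "h".
    h = checkpoint_name(jnp.sin(x), "h")
    return h * 2.0

def unlabeled(x):
    # Same computation without any label.
    return jnp.sin(x) * 2.0

x = jnp.linspace(0.0, 1.0, 4)

# The label acts as an identity on values.
same_values = bool(jnp.allclose(labeled(x), unlabeled(x)))
print(same_values)

# Inspect the traced program to see where the label was recorded.
print(jax.make_jaxpr(labeled)(x))
```

Because labeling is numerically a no-op, names can be added to existing model code first and the policy that uses them can be chosen later.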
Checkpoint Policies#
The jax.remat() transformation manages memory by handling intermediate values through three strategies:
Recomputing during backward pass (default behavior)
Storing on device
Offloading to host memory after forward pass and loading back during backward pass
Example of setting an offloading checkpoint policy:
from jax import checkpoint_policies as cp
policy = cp.save_and_offload_only_these_names(
    names_which_can_be_saved=[],         # No values stored on device
    names_which_can_be_offloaded=["x"],  # Offload activations labeled "x"
    offload_src="device",                # Move from device memory
    offload_dst="pinned_host"            # To pinned host memory
)
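The choice of strategy affects memory use, not the computed gradients. As a rough illustration, the sketch below compares the first two strategies (recompute everything, and keep a named value on device) on a toy function and checks that the gradients agree. The function `block`, the label `"h"`, and the array shapes are assumptions made for this example; an offloading policy like the one above could be swapped in on backends that support pinned host memory.

```python
import jax
import jax.numpy as jnp
from jax import checkpoint_policies as cp
from jax.ad_checkpoint import checkpoint_name

def block(x, w):
    # Label the activation so a name-based policy can select it.
    h = checkpoint_name(jnp.tanh(x @ w), "h")
    return jnp.sum(h ** 2)

x = jnp.ones((4, 8))
w = jnp.full((8, 8), 0.1)

# Strategy 1: save nothing, recompute intermediates in the backward pass.
recompute = jax.remat(block, policy=cp.nothing_saveable)
# Strategy 2: keep the value labeled "h" in device memory.
keep = jax.remat(block, policy=cp.save_only_these_names("h"))

g_recompute = jax.grad(recompute, argnums=1)(x, w)
g_keep = jax.grad(keep, argnums=1)(x, w)

# Both policies should produce the same gradient with respect to w.
policies_agree = bool(jnp.allclose(g_recompute, g_keep))
print(policies_agree)
```

Checking agreement against a simpler policy like this is one way to validate a new offloading setup before measuring its memory impact.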
Since jax.lax.scan() is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context.
Key components:
jax.remat() applies our checkpoint policy to the layer function
prevent_cse=False enables XLA’s common subexpression elimination for better performance
jax.lax.scan() iterates the rematerialized layer along an axis
def scanned(w, x):
    remat_layer = jax.remat(layer,
                            policy=policy,      # Use our offloading policy
                            prevent_cse=False)  # Allow CSE optimizations
    result = jax.lax.scan(remat_layer, x, w)[0]
    return jnp.sum(result)
# Initialize input and weights with small values (0.001)
input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256
w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices
w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices
# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation
result_activation = f((w1, w2), input) # Execute the function with weights and input
print("Sample of results: ", result_activation[0][0, 0, :5])
Sample of results: [3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07]
Summary of Activation Offloading#
Activation offloading provides a powerful way to manage memory in large computations by:
Using checkpoint names to mark specific activations
Applying policies to control where and how activations are stored
Supporting common JAX patterns like scan operations
Moving selected activations to host memory when device memory is constrained
This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.
Parameter Offloading#
Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using jax.jit() with a sharding strategy that specifies host memory kind.
While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.
Parameter Placement for Computation#
Unlike in the earlier layer function, jax.device_put() is applied to move the parameters w1 and w2 to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.
Note that the activation offloading implementation remains unchanged, using the same:
Checkpoint name "x"
Checkpoint policy
scanned function combining jax.remat() and jax.lax.scan()
Parameter Initialization with Host Offloading#
During initialization, the parameters w1 and w2 are placed in host memory before being passed to the jax.jit() function f, while the input variable stays on the device.
# Hybrid version: Both activation and parameter offloading
def hybrid_layer(x, w):
    # Move model parameters w1 and w2 from host memory to device memory via device_put
    w1, w2 = jax.tree.map(lambda p: jax.device_put(p, s_dev), w)
    x = checkpoint_name(x, "x")  # Label activation x so the policy can offload it to host memory
    y = x @ w1
    return y @ w2, None
def hybrid_scanned(w, x):
    remat_layer = jax.remat(hybrid_layer,       # Use hybrid_layer instead of layer
                            policy=policy,      # Use offloading policy
                            prevent_cse=False)  # Allow CSE optimizations
    result = jax.lax.scan(remat_layer, x, w)[0]
    return jnp.sum(result)
# Reuse the weights initialized earlier and move w1 and w2 to pinned host memory via device_put
wh1 = jax.device_put(w1, s_host)
wh2 = jax.device_put(w2, s_host)
# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation
result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading
# Verify numerical correctness
are_close = jnp.allclose(
    result_activation[0],  # Result from activation offloading only
    result_both[0],        # Result from both activation and parameter offloading
    rtol=1e-5,
    atol=1e-5
)
print(f"Results match within tolerance: {are_close}")
Results match within tolerance: True
The matching results verify that initializing parameters on host memory maintains computational correctness.
Limitation of Parameter Offloading#
jax.lax.scan() is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While jax.lax.scan() allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a transpose operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.
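If stacked parameters happen to be stored with the layer axis somewhere other than axis 0, one possible workaround (an assumption here, not something the tutorial prescribes) is to move that axis to the front once, up front, so that the scan itself runs over the leading axis. The sketch below uses a hypothetical layout with the layer axis at position 1 and checks the scanned result against a plain Python loop. It does not perform any offloading; it only illustrates the layout change.

```python
import jax
import jax.numpy as jnp

num_layers, d = 3, 4

# Hypothetical layout: weights stored as (d, num_layers, d), layer axis = 1.
w_stacked = jnp.arange(d * num_layers * d, dtype=jnp.float32).reshape(d, num_layers, d) * 0.01

# Move the layer axis to the front once, before any scanning or offloading.
w_layers_first = jnp.moveaxis(w_stacked, 1, 0)  # shape (num_layers, d, d)

def step(x, w):
    # One layer: matrix multiply followed by a nonlinearity.
    return jnp.tanh(x @ w), None

x0 = jnp.ones((2, d))
scanned_out, _ = jax.lax.scan(step, x0, w_layers_first)

# Reference: explicit loop over the original layout.
ref = x0
for i in range(num_layers):
    ref = jnp.tanh(ref @ w_stacked[:, i, :])

matches_loop = bool(jnp.allclose(scanned_out, ref, atol=1e-5))
print(w_layers_first.shape, matches_loop)
```

Doing the rearrangement a single time at initialization keeps the transpose out of the per-step path, although whether this is cheaper overall depends on the model and platform.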
Tools for Host Offloading#
For device memory analysis, refer to the device memory profiling guide (device_memory_profiling). The profiling tools described in Profiling and Tracing can help measure memory savings and performance impact from host offloading.
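As a lightweight complement to full profiling, compile-time memory statistics can give a first estimate of how a rematerialization choice changes temporary memory. The following is a sketch under assumptions: it relies on `Compiled.memory_analysis()`, which may return `None` or be unimplemented on some backends, and on the statistics object exposing a `temp_size_in_bytes` field. The helper `temp_bytes` and the toy `layer` are hypothetical, and no offloading policy is used so that the example stays portable.

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w), None

def loss(ws, x):
    # Plain scan: intermediates needed for the backward pass are stored.
    return jnp.sum(jax.lax.scan(layer, x, ws)[0])

def loss_remat(ws, x):
    # Rematerialized scan body: intermediates are recomputed instead.
    return jnp.sum(jax.lax.scan(jax.remat(layer), x, ws)[0])

ws = jnp.ones((4, 64, 64)) * 0.01
x = jnp.ones((32, 64))

def temp_bytes(fn):
    """Return the compiler's temporary-memory estimate, or None if unavailable."""
    compiled = jax.jit(jax.grad(fn)).lower(ws, x).compile()
    try:
        stats = compiled.memory_analysis()
    except NotImplementedError:
        return None
    return getattr(stats, "temp_size_in_bytes", None)

print("plain:", temp_bytes(loss))
print("remat:", temp_bytes(loss_remat))
```

These numbers are compiler estimates rather than measurements, so they are best used for quick comparisons between variants and confirmed with the profiling tools mentioned above.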