jax.shard_map
- jax.shard_map(f=None, /, *, out_specs, axis_names={}, in_specs=None, mesh=None, check_vma=True)
Map a function over shards of data using a mesh of devices.
See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.
- Parameters:
  - f – callable to be mapped. Each application of f, or "instance" of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.
  - mesh (Mesh | AbstractMesh | None) – (optional, default None) a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The axis names of the Mesh can be used in collective communication operations in f. If mesh is None, it is inferred from the context, which can be set via the jax.sharding.use_mesh context manager.
  - in_specs (Specs | None) – (optional, default None) a pytree with jax.sharding.PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to jax.sharding.NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position means the argument's array axis at that position is sharded along that mesh axis; not mentioning an axis name expresses replication. If None, all mesh axes must be of type Explicit, in which case the in_specs are inferred from the argument types.
  - out_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position means that mesh axis's shards are concatenated along the output's array axis at that position. Not mentioning a mesh axis name is a promise that the output values are equal along that mesh axis, so that a single value is produced rather than a concatenation.
  - axis_names (Set[AxisName]) – (optional, default set()) set of axis names from mesh over which the function f is manual. If empty, f is manual over all mesh axes.
  - check_vma (bool) – (optional, default True) whether to enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated.
- Returns:
  A callable representing a mapped version of f, which accepts positional arguments corresponding to those of f and produces output corresponding to that of f.