jax.lax.axis_size
- jax.lax.axis_size(axis_name)
Return the size of the mapped axis axis_name.
- Parameters:
axis_name (AxisName) – hashable Python object used to name the mapped axis.
- Returns:
An integer representing the size.
- Return type:
int
For example, with 8 XLA devices available:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> from functools import partial
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((8,), 'i')
>>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
... def f(_):
...   return lax.axis_size('i')
...
>>> f(jnp.zeros(16))
Array(8, dtype=int32, weak_type=True)
>>> # With a tuple of axis names, the result is the product of their sizes (4 * 2).
>>> mesh = jax.make_mesh((4, 2), ('i', 'j'))
>>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P())
... def f(_):
...   return lax.axis_size(('i', 'j'))
...
>>> f(jnp.zeros((16, 8)))
Array(8, dtype=int32, weak_type=True)