jax.export.symbolic_args_specs
jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)
Constructs a pytree of jax.ShapeDtypeStruct argument specs for use with jax.export.export().
See the documentation of jax.export.symbolic_shape() and the [shape polymorphism documentation](https://docs.jax.dev/en/latest/export/shape_poly.html) for details.
- Parameters:
  - args – a pytree of arguments. These can be jax.Array or jax.ShapeDtypeStruct values. They are used to learn the pytree structure of the arguments and their dtypes, and to fill in the actual shapes where shapes_specs contains placeholders (_ for a single dimension, ... for all remaining dimensions). Note that only the shape dimensions for which shapes_specs has a placeholder are taken from args.
  - shapes_specs – either None (all arguments have static shapes), a single string (see shape_spec for jax.export.symbolic_shape(); applies to all arguments), or a pytree matching a prefix of args. See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
  - constraints (Sequence[str]) – as for jax.export.symbolic_shape().
  - scope (SymbolicScope | None) – as for jax.export.symbolic_shape().
- Returns: a pytree of jax.ShapeDtypeStruct matching args, with the shapes replaced by symbolic dimensions as specified by shapes_specs.
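The sketch below assumes a JAX release that ships jax.export. It exercises the three forms of shapes_specs and passes one result to jax.export.export(). The function f, the dictionary keys w and idx, and the constraint "b >= 4" are hypothetical, chosen only for illustration. Following the shape polymorphism documentation, NumPy arrays serve as the concrete args.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x, y):
    # Hypothetical function: x: f32[a, 1] broadcasts against y: f32[a, 4].
    return x + y


# Concrete example arguments. Only their pytree structure, their dtypes, and
# the dimensions covered by placeholders ("_" or "...") are read from them.
x = np.ones((3, 1), dtype=np.float32)
y = np.ones((3, 4), dtype=np.float32)

# 1. A single string applies to every argument: the leading dimension becomes
#    the symbolic "a", and "..." copies the remaining sizes from the args.
args_specs = export.symbolic_args_specs((x, y), "a, ...")
exp = export.export(jax.jit(f))(*args_specs)
print(exp.in_avals)
print(exp.out_avals)

# 2. A pytree prefix gives one spec per leaf. "_" copies a single dimension
#    from the matching argument. The constraint (illustrative only) lets
#    comparisons such as b >= 4 be decided.
params = {"w": np.zeros((8, 16), dtype=np.float32),
          "idx": np.zeros((8,), dtype=np.int32)}
param_specs = export.symbolic_args_specs(
    params, {"w": "b, _", "idx": "b"}, constraints=("b >= 4",))
print(param_specs)

# 3. shapes_specs=None keeps every dimension static; jax.Array inputs work
#    too, and nested containers are preserved in the result.
static_specs = export.symbolic_args_specs(
    (jnp.ones((2, 3)), {"w": jnp.zeros((3,), dtype=jnp.int32)}), None)
print(static_specs)
```

Because a single symbolic scope is built from constraints and shared by all leaves, the name a in the spec for x and the name a in the spec for y denote the same symbolic dimension, which is what allows x + y to trace with symbolic shapes.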