jax.export.symbolic_args_specs

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)

Constructs a pytree of jax.ShapeDtypeStruct argument specs, with symbolic dimensions, for use with jax.export.export().

See the documentation of jax.export.symbolic_shape() and the [shape polymorphism documentation](https://docs.jax.dev/en/latest/export/shape_poly.html) for details.
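A minimal usage sketch, assuming a recent JAX release that provides jax.export. The concrete arrays x and y supply the pytree structure, the dtypes and the size kept by the "_" placeholder. The spec tuple ("b, _", "b") makes the leading dimension of both arguments one shared symbolic dimension b. The resulting specs are then passed to jax.export.export(). The function f and all array sizes are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import export

# Concrete arrays: they supply the pytree structure, the dtypes and the
# size used for the "_" placeholder. Their leading size 8 is replaced by "b".
x = jnp.zeros((8, 4), dtype=jnp.float32)
y = jnp.zeros((8,), dtype=jnp.int32)

# A tuple of specs matching the tuple of args. The same name "b" in both
# specs denotes one shared symbolic dimension; "_" keeps the size 4 from x.
args_specs = export.symbolic_args_specs((x, y), ("b, _", "b"))
print(args_specs)

# An illustrative function to export with a symbolic leading dimension.
def f(x, y):
    return jnp.sum(x, axis=1) + y.astype(jnp.float32)

exp = export.export(jax.jit(f))(*args_specs)
print(exp.in_avals)

# The exported artifact accepts any concrete value of b, here b = 3.
res = exp.call(jnp.ones((3, 4), jnp.float32), jnp.arange(3, dtype=jnp.int32))
print(res)
```

Because the specs are ShapeDtypeStruct values rather than arrays, the same exported function can be called with different concrete leading sizes without re-exporting.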

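Parameters:
- args: a pytree of arguments (jax.Array values or jax.ShapeDtypeStruct specs). It determines the pytree structure and dtypes of the result, and supplies the concrete size of any dimension that shapes_specs leaves as a "_" placeholder.
- shapes_specs: None (all arguments keep their static shapes), a single string in the shape_spec syntax of jax.export.symbolic_shape() (applied to every argument), or a pytree of such specs matching a prefix of args.
- constraints: symbolic constraints, as for jax.export.symbolic_shape().
- scope: a symbolic scope, as for jax.export.symbolic_shape().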
Returns: a pytree of jax.ShapeDtypeStruct matching args, with the shapes replaced by symbolic dimensions as specified by shapes_specs.
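The sketch below shows how the returned pytree mirrors args when shapes_specs is a single string or None. It assumes jax.ShapeDtypeStruct values are accepted as args and that a None spec keeps the static shapes; both are our reading of the API, not guarantees. The dictionary keys "u" and "v", the sizes, and the constraint "b >= 16" are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export

# args given as jax.ShapeDtypeStruct values; only structure, dtypes and the
# size used for the "_" placeholder matter. The leading 32s are replaced by b.
args = {
    "u": jax.ShapeDtypeStruct((32, 3), jnp.float32),
    "v": jax.ShapeDtypeStruct((32, 3), jnp.float32),
}

# A single string is applied to every leaf of args. The constraint belongs
# to the symbolic scope shared by all dimensions created in this call.
specs = export.symbolic_args_specs(args, "b, _", constraints=("b >= 16",))
b = specs["u"].shape[0]
print(specs)
print(b >= 16)  # decidable only because of the "b >= 16" constraint

# A shapes_specs of None keeps every argument's static shape.
static = export.symbolic_args_specs(args, None)
print(static["u"].shape)
```

Without the constraint, comparing b >= 16 would be inconclusive, since a symbolic dimension is only assumed to be at least 1.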