jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None, symbolic_constraints=(), symbolic_scope=None)[source]#
为 export 构造 jax.ShapeDtypeSpec 参数规范的 pytree。
有关详细信息,请参阅
jax.export.symbolic_shape()
的文档和 [形状多态性文档](https://jax.ac.cn/en/latest/export/shape_poly.html)。- 参数:
args – 参数的 pytree。这些可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 pytree 结构、其数据类型,并填充 shapes_specs 中包含占位符的实际形状。请注意,仅使用 shapes_specs 为占位符的形状维度来自 args。
shapes_specs – 应为 None(所有参数具有静态形状),单个字符串(有关 shape_spec,请参阅
jax.export.symbolic_shape()
;适用于所有参数),或与 args 的前缀匹配的 pytree。请参阅 [如何将可选参数与参数匹配](https://jax.ac.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (Sequence[str]) – 与
jax.export.symbolic_shape()
相同。scope (SymbolicScope | None | None) – 与
jax.export.symbolic_shape()
相同。symbolic_constraints (Sequence[str]) – 已弃用,请使用 constraints。
symbolic_scope (SymbolicScope | None | None) – 已弃用,请使用 scope。
- 返回:与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其中形状
已替换为由 shapes_specs 指定的符号维度。