jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源代码]#
为 export 构建 jax.ShapeDtypeSpec 参数规范的 pytree。
有关详细信息,请参阅
jax.export.symbolic_shape()
和 [形状多态性文档](https://jax.ac.cn/en/latest/export/shape_poly.html)的文档。- 参数:
args – 参数的 pytree。这些可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 pytree 结构、它们的 dtypes 以及在 shapes_specs 包含占位符的位置填充实际形状。请注意,仅使用 args 中 shapes_specs 是占位符的形状维度。
shapes_specs – 应该为 None (所有参数都有静态形状),单个字符串 (请参阅 shape_spec,了解
jax.export.symbolic_shape()
;应用于所有参数),或匹配 args 前缀的 pytree。 请参阅 [如何将可选参数匹配到参数](https://jax.ac.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (Sequence[str]) – 与
jax.export.symbolic_shape()
的用法相同。scope (SymbolicScope | None | None) – 与
jax.export.symbolic_shape()
的用法相同。
- 返回值:与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其中形状
根据 shapes_specs 的指定,替换为符号维度。