jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源代码]#
为 export 构建 jax.ShapeDtypeSpec 参数规范的 pytree。
详情请参阅
jax.export.symbolic_shape()
的文档以及 [形状多态性文档](https://jax.ac.cn/en/latest/export/shape_poly.html)。- 参数:
args – 参数的 pytree。可以是 jax.Array 或 jax.ShapeDTypeSpec。它们用于学习参数的 pytree 结构、dtypes,并在 shapes_specs 包含占位符的情况下填充实际形状。请注意,只有 shapes_specs 为占位符的形状维度才会使用 args 中的值。
shapes_specs – 应该是 None (所有参数都有静态形状),一个字符串(请参阅
jax.export.symbolic_shape()
的 shape_spec;适用于所有参数),或与 args 的前缀匹配的 pytree。请参阅 [可选参数如何与参数匹配](https://jax.ac.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (Sequence[str]) – 与
jax.export.symbolic_shape()
相同。scope (SymbolicScope | None | None) – 与
jax.export.symbolic_shape()
相同。
- 返回: 与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其形状
根据 shapes_specs 的指定替换为符号维度。