jax.export.symbolic_args_specs#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[源代码]#

export 构建 jax.ShapeDtypeSpec 参数规范的 pytree。

详情请参阅 jax.export.symbolic_shape() 的文档以及 [形状多态性文档](https://jax.ac.cn/en/latest/export/shape_poly.html)。

参数:
返回: 与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其形状

根据 shapes_specs 的指定替换为符号维度。