jax.export.symbolic_args_specs

jax.export.symbolic_args_specs#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None, symbolic_constraints=(), symbolic_scope=None)[source]#

export 构造 jax.ShapeDtypeSpec 参数规范的 pytree。

有关详细信息,请参阅 jax.export.symbolic_shape() 的文档和 [形状多态性文档](https://jax.ac.cn/en/latest/export/shape_poly.html)。

参数:
返回:与 args 匹配的 jax.ShapeDTypeStruct 的 pytree,其中形状

已替换为由 shapes_specs 指定的符号维度。