jax.make_jaxpr#
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) Callable[..., core.ClosedJaxpr] [源代码]#
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) Callable[..., tuple[core.ClosedJaxpr, Any]]
创建一个函数,该函数在给定示例参数的情况下生成其 jaxpr。
- 参数:
fun – 要计算其
jaxpr
的函数。其位置参数和返回值应为数组、标量或它们的标准 Python 容器(元组/列表/字典)。static_argnums – 请参阅
jax.jit()
文档字符串。axis_env – 可选,成对的序列,其中第一个元素是轴名称,第二个元素是表示具有该名称的映射轴大小的正整数。此参数在降低涉及并行通信集合的函数时很有用,它指定了
jax.pmap()
应用将设置的轴名称/大小环境。return_shape – 可选布尔值,默认为
False
。如果为True
,则包装的函数返回一个对,其中第一个元素是fun
的ClosedJaxpr
表示形式,第二个元素是一个与fun
的输出具有相同结构的 pytree,其中叶子是具有shape
和dtype
属性的对象,表示输出叶子的相应类型。
- 返回值:
一个
fun
的包装版本,当应用于示例参数时,返回fun
在这些参数上的ClosedJaxpr
表示形式。如果参数return_shape
为True
,则返回的函数将返回一个对,其中第一个元素是fun
的ClosedJaxpr
表示形式,第二个元素是一个 pytree,表示fun
输出的结构、形状、dtypes 和命名形状。
jaxpr
是 JAX 用于程序跟踪的中间表示形式。jaxpr
语言基于具有 let 绑定的简单类型一阶 lambda 演算。make_jaxpr()
调整函数以返回其jaxpr
,我们可以检查它来了解 JAX 在内部执行的操作。返回的jaxpr
是fun
抽象到ShapedArray
级别的跟踪。内部存在其他抽象级别。我们在此不详细描述
jaxpr
语言的语义,而是给出一些示例。>>> import jax >>> >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. let b:f32[] = cos a c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c in (g,) }