jax.eval_shape

内容

jax.eval_shape#

jax.eval_shape(fun, *args, **kwargs)[source]#

在不进行任何浮点运算的情况下计算 fun 的形状/数据类型。

此实用程序函数可用于执行形状推断。其输入/输出行为由以下定义:

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

但它不直接应用 fun(这可能很昂贵),而是使用 JAX 的抽象解释机制在不执行任何浮点运算的情况下评估形状。

使用 eval_shape() 还可以捕获形状错误,并会引发与评估 fun(*args, **kwargs) 相同的形状错误。

参数:
  • fun (Callable) – 应该评估其输出形状的函数。

  • *args – 数组、标量或(嵌套)标准 Python 容器(元组、列表、字典、命名元组,即 pytrees)的元组位置参数。由于仅访问 shapedtype 属性,因此可以使用 jax.ShapeDtypeStruct 或其他作为 ndarray 进行鸭子类型的容器(但请注意,鸭子类型对象不能是命名元组,因为这些对象被视为标准 Python 容器)。

  • **kwargs** – 一个关键字参数字典,包含数组、标量或(嵌套)标准 Python 容器(pytrees),这些容器包含这些类型的值。与 args 中一样,数组值只需要是鸭子类型,具有 shapedtype 属性即可。

返回值:

一个嵌套的 PyTree,其叶子包含 jax.ShapeDtypeStruct 对象。

返回类型:

out

例如

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32

通过 eval_shape() 传递的所有参数将被视为动态的;静态参数可以通过闭包包含,例如使用 functools.partial()

>>> import jax
>>> from jax import lax
>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
>>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
>>>
>>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
>>> out = jax.eval_shape(conv_same, x, kernel)
>>> print(out.shape)
(1, 32, 28, 28)
>>> print(out.dtype)
float32