jax.numpy.linalg.multi_dot#
- jax.numpy.linalg.multi_dot(arrays, *, precision=None)[源代码]#
高效计算一系列数组之间的矩阵乘积。
JAX 实现的
numpy.linalg.multi_dot()
。JAX 内部使用 opt_einsum 库来计算最有效的运算顺序。
- 参数:
arrays (Sequence[ArrayLike]) – 数组序列。除了第一个和最后一个可以是一维之外,其他都必须是二维的。
precision (PrecisionLike | None) – 可以是
None
(默认),表示后端的默认精度;或者是一个Precision
枚举值(Precision.DEFAULT
,Precision.HIGH
或Precision.HIGHEST
)。
- 返回:
一个数组,表示等效于
reduce(jnp.matmul, arrays)
的结果,但以最佳顺序进行计算。- 返回类型:
此函数的存在是因为计算矩阵乘法操作序列的成本可能因操作评估的顺序而差异很大。对于单个矩阵乘法,计算矩阵乘积所需的浮点运算次数 (flops) 可以这样近似
>>> def approx_flops(x, y): ... # for 2D x and y, with x.shape[1] == y.shape[0] ... return 2 * x.shape[0] * x.shape[1] * y.shape[1]
假设我们有三个矩阵,我们想按顺序相乘
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> x = jax.random.normal(key1, shape=(200, 5)) >>> y = jax.random.normal(key2, shape=(5, 100)) >>> z = jax.random.normal(key3, shape=(100, 10))
由于矩阵乘法的结合律,我们可以使用两种顺序来计算乘积
x @ y @ z
,并且在浮点精度范围内,两者都产生等效的输出>>> result1 = (x @ y) @ z >>> result2 = x @ (y @ z) >>> jnp.allclose(result1, result2, atol=1E-4) Array(True, dtype=bool)
但是它们的计算成本差异很大
>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z)) (x @ y) @ z flops: 600000 >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z)) x @ (y @ z) flops: 30000
第二种方法在估计的浮点运算次数方面效率高出约 20 倍!
multi_dot
是一个函数,可以自动为这类问题选择最快的计算路径>>> result3 = jnp.linalg.multi_dot([x, y, z]) >>> jnp.allclose(result1, result3, atol=1E-4) Array(True, dtype=bool)
我们可以使用 JAX 的 提前降低和编译 工具来估计每种方法的总浮点运算次数,并确认
multi_dot
选择的是更高效的选项>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops'] 600000.0 >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops'] 30000.0 >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops'] 30000.0