jax.numpy.linalg.multi_dot#

jax.numpy.linalg.multi_dot(arrays, *, precision=None)[源代码]#

高效计算一系列数组之间的矩阵乘积。

numpy.linalg.multi_dot() 的 JAX 实现。

JAX 内部使用 opt_einsum 库来计算最高效的运算顺序。

参数:
  • arrays (Sequence[ArrayLike]) – 数组序列。所有数组必须是二维的,除了第一个和最后一个可以是是一维的。

  • precision (PrecisionLike | None) – 可以是 None (默认值),表示使用后端的默认精度;也可以是一个 Precision 枚举值 (Precision.DEFAULT, Precision.HIGHPrecision.HIGHEST)。

返回值:

一个数组,表示与 reduce(jnp.matmul, arrays) 等效的结果,但以最优顺序进行求值。

返回类型:

Array

此函数存在的原因是,计算一系列矩阵乘法操作的成本可能因操作的求值顺序而差异巨大。 对于单个矩阵乘法,计算矩阵乘积所需的浮点运算次数(flops)可以通过以下方式进行近似计算

>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

假设我们有三个矩阵,我们想按顺序进行相乘

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))

由于矩阵乘法的结合律,我们可以用两种顺序来求值乘积 x @ y @ z,并且两种方式都会产生等效的输出(在浮点精度范围内)。

>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)

但是它们的计算成本差异很大

>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000

第二种方法在估计的浮点运算次数方面效率高出约 20 倍!

multi_dot 是一个函数,它将自动为这类问题选择最快的计算路径

>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)

我们可以使用 JAX 的 预先降低和编译 工具来估计每种方法的总浮点运算次数,并确认 multi_dot 选择的是更高效的选项

>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0