jax.numpy.einsum#
- jax.numpy.einsum(subscripts, /, *operands, out=None, optimize='optimal', precision=None, preferred_element_type=None, _dot_general=<function dot_general>, out_type=None)[源代码]#
爱因斯坦求和
JAX 实现的
numpy.einsum()
。einsum
是一个强大且通用的 API,用于计算各种缩减、内积、外积、轴重排序以及一个或多个输入数组的组合。它有一个稍微复杂的重载 API;下面的参数反映了最常见的调用约定。下面的示例部分演示了一些替代的调用约定。- 参数:
subscripts – 包含以逗号分隔的轴名称的字符串。
*operands – 与下标对应的一个或多个数组序列。
optimize (str | bool | list[tuple[int, ...]]) – 指定如何优化计算顺序。在 JAX 中,默认为
"optimal"
,它通过 opt_einsum 包生成优化的表达式。其他选项包括True
(与"optimal"
相同)、False
(未优化)或opt_einsum
支持的任何字符串,其中包括"auto"
、"greedy"
、"eager"
等。它也可以是预先计算的路径(请参阅einsum_path()
)。precision (PrecisionLike | None) –
None
(默认),表示后端的默认精度,或者Precision
枚举值(Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
)。preferred_element_type (DTypeLike | None | None) –
None
(默认),表示输入类型的默认累积类型,或者一个数据类型,表示将结果累积到该数据类型并返回具有该数据类型的结果。out (None | None) – JAX 不支持
_dot_general (Callable[..., Array]) – 可选择覆盖
einsum
使用的dot_general
可调用对象。此参数是实验性的,可能会在不发出警告的情况下随时删除。
- 返回:
包含爱因斯坦求和结果的数组。
- 返回类型:
示例
einsum
的机制最好通过示例来演示。在这里,我们展示如何使用einsum
从一个或多个数组计算多个量。有关einsum
的更多讨论和示例,请参阅numpy.einsum()
的文档。>>> M = jnp.arange(16).reshape(4, 4) >>> x = jnp.arange(4) >>> y = jnp.array([5, 4, 3, 2])
向量积
>>> jnp.einsum('i,i', x, y) Array(16, dtype=int32) >>> jnp.vecdot(x, y) Array(16, dtype=int32)
以下是一些替代的
einsum
调用约定来计算相同的结果>>> jnp.einsum('i,i->', x, y) # explicit form Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,)) # implicit form via indices Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,), ()) # explicit form via indices Array(16, dtype=int32)
矩阵积
>>> jnp.einsum('ij,j->i', M, x) # explicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.matmul(M, x) Array([14, 38, 62, 86], dtype=int32)
以下是一些替代的
einsum
调用约定来计算相同的结果>>> jnp.einsum('ij,j', M, x) # implicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,)) # implicit form via indices Array([14, 38, 62, 86], dtype=int32)
外积
>>> jnp.einsum("i,j->ij", x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.outer(x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32)
计算外积的其他一些方法
>>> jnp.einsum("i,j", x, y) # implicit form Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,), (0, 1)) # explicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,)) # implicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32)
1D 数组求和
>>> jnp.einsum("i->", x) # requires explicit form Array(6, dtype=int32) >>> jnp.einsum(x, (0,), ()) # explicit form via indices Array(6, dtype=int32) >>> jnp.sum(x) Array(6, dtype=int32)
沿轴求和
>>> jnp.einsum("...j->...", M) # requires explicit form Array([ 6, 22, 38, 54], dtype=int32) >>> jnp.einsum(M, (..., 0), (...,)) # explicit form via indices Array([ 6, 22, 38, 54], dtype=int32) >>> M.sum(-1) Array([ 6, 22, 38, 54], dtype=int32)
矩阵转置
>>> y = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.einsum("ij->ji", y) # explicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum("ji", y) # implicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (1, 0)) # implicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (0, 1), (1, 0)) # explicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.transpose(y) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
矩阵对角线
>>> jnp.einsum("ii->i", M) Array([ 0, 5, 10, 15], dtype=int32) >>> jnp.diagonal(M) Array([ 0, 5, 10, 15], dtype=int32)
矩阵迹
>>> jnp.einsum("ii", M) Array(30, dtype=int32) >>> jnp.trace(M) Array(30, dtype=int32)
张量积
>>> x = jnp.arange(30).reshape(2, 3, 5) >>> y = jnp.arange(60).reshape(3, 4, 5) >>> jnp.einsum('ijk,jlk->il', x, y) # explicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)]) Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum('ijk,jlk', x, y) # implicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3)) # explicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2)) # implicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32)
链式点积
>>> w = jnp.arange(5, 9).reshape(2, 2) >>> x = jnp.arange(6).reshape(2, 3) >>> y = jnp.arange(-2, 4).reshape(3, 2) >>> z = jnp.array([[2, 4, 6], [3, 5, 7]]) >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4)) # implicit, via indices Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> w @ x @ y @ z # direct chain of matmuls Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.linalg.multi_dot([w, x, y, z]) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32)