jax.numpy.einsum#
- jax.numpy.einsum(subscript: str, /, *operands: ArrayLike, out: None = None, optimize: str | bool | list[tuple[int, ...]] = 'optimal', precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general) Array [source]#
- jax.numpy.einsum(arr: ArrayLike, axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], out: None = None, optimize: str | bool | list[tuple[int, ...]] = 'optimal', precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general) Array
爱因斯坦求和
JAX 实现的
numpy.einsum()
。einsum
是一个强大且通用的 API,用于计算一个或多个输入数组上的各种约简、内积、外积、轴重排序以及它们的组合。它具有一个有点复杂的重载 API;下面的参数反映了最常见的调用约定。下面的示例部分演示了一些替代的调用约定。- 参数:
subscripts – 包含以逗号分隔的轴名称的字符串。
*operands – 与下标相对应的一个或多个数组的序列。
optimize – 指定如何优化计算顺序。在 JAX 中,这默认为
"optimal"
,它通过 opt_einsum 包生成优化的表达式。其他选项为True
(与"optimal"
相同)、False
(未优化)或opt_einsum
支持的任何字符串,包括"auto"
、"greedy"
、"eager"
等。它也可以是预先计算的路径(参见einsum_path()
)。precision –
None
(默认值),这意味着后端的默认精度,一个Precision
枚举值(Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
)。preferred_element_type –
None
(默认值),这意味着输入类型的默认累积类型,或数据类型,指示将结果累积到并返回具有该数据类型的结果。out – JAX 不支持
_dot_general – 可选地覆盖
einsum
使用的dot_general
可调用对象。此参数为实验性参数,可能随时在不另行通知的情况下移除。
- 返回值:
包含爱因斯坦求和结果的数组。
示例
也许最好通过示例来演示
einsum
的机制。这里我们展示如何使用einsum
从一个或多个数组计算多个量。有关einsum
的更多讨论和示例,请参阅numpy.einsum()
的文档。>>> M = jnp.arange(16).reshape(4, 4) >>> x = jnp.arange(4) >>> y = jnp.array([5, 4, 3, 2])
向量积
>>> jnp.einsum('i,i', x, y) Array(16, dtype=int32) >>> jnp.vecdot(x, y) Array(16, dtype=int32)
以下是一些计算相同结果的替代
einsum
调用约定>>> jnp.einsum('i,i->', x, y) # explicit form Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,)) # implicit form via indices Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,), ()) # explicit form via indices Array(16, dtype=int32)
矩阵积
>>> jnp.einsum('ij,j->i', M, x) # explicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.matmul(M, x) Array([14, 38, 62, 86], dtype=int32)
以下是一些计算相同结果的替代
einsum
调用约定>>> jnp.einsum('ij,j', M, x) # implicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,)) # implicit form via indices Array([14, 38, 62, 86], dtype=int32)
外积
>>> jnp.einsum("i,j->ij", x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.outer(x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32)
计算外积的其他一些方法
>>> jnp.einsum("i,j", x, y) # implicit form Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,), (0, 1)) # explicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,)) # implicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32)
一维数组求和
>>> jnp.einsum("i->", x) # requires explicit form Array(6, dtype=int32) >>> jnp.einsum(x, (0,), ()) # explicit form via indices Array(6, dtype=int32) >>> jnp.sum(x) Array(6, dtype=int32)
沿轴求和
>>> jnp.einsum("...j->...", M) # requires explicit form Array([ 6, 22, 38, 54], dtype=int32) >>> jnp.einsum(M, (..., 0), (...,)) # explicit form via indices Array([ 6, 22, 38, 54], dtype=int32) >>> M.sum(-1) Array([ 6, 22, 38, 54], dtype=int32)
矩阵转置
>>> y = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.einsum("ij->ji", y) # explicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum("ji", y) # implicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (1, 0)) # implicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (0, 1), (1, 0)) # explicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.transpose(y) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
矩阵对角线
>>> jnp.einsum("ii->i", M) Array([ 0, 5, 10, 15], dtype=int32) >>> jnp.diagonal(M) Array([ 0, 5, 10, 15], dtype=int32)
矩阵迹
>>> jnp.einsum("ii", M) Array(30, dtype=int32) >>> jnp.trace(M) Array(30, dtype=int32)
张量积
>>> x = jnp.arange(30).reshape(2, 3, 5) >>> y = jnp.arange(60).reshape(3, 4, 5) >>> jnp.einsum('ijk,jlk->il', x, y) # explicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)]) Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum('ijk,jlk', x, y) # implicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3)) # explicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2)) # implicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32)
链式点积
>>> w = jnp.arange(5, 9).reshape(2, 2) >>> x = jnp.arange(6).reshape(2, 3) >>> y = jnp.arange(-2, 4).reshape(3, 2) >>> z = jnp.array([[2, 4, 6], [3, 5, 7]]) >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4)) # implicit, via indices Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> w @ x @ y @ z # direct chain of matmuls Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.linalg.multi_dot([w, x, y, z]) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32)