jax.numpy.cross#

jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[源代码]#

计算两个数组的(批处理)叉积。

numpy.cross() 的 JAX 实现。

这将计算二维或三维的叉积,

\[c = a \times b\]

在三维空间中,c 是一个长度为 3 的数组。在二维空间中,c 是一个标量。

参数:
  • a – N 维数组。a.shape[axisa] 表示叉积的维度,必须为 2 或 3。

  • b – N 维数组。必须有 b.shape[axisb] == a.shape[axisb],并且 ab 的其他维度必须具有广播兼容性。

  • axisa (int) – 指定 a 中计算叉积的轴。

  • axisb (int) – 指定 b 中计算叉积的轴。

  • axisc (int) – 指定 c 中存储叉积结果的轴。

  • axis (int | None) – 如果指定,则使用单个值覆盖 axisaaxisbaxisc

返回:

数组 c,包含 ab 沿指定轴的(批处理)叉积。

另请参阅

示例

二维叉积返回一个标量

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

三维叉积返回一个长度为 3 的向量

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

对于多维输入,默认情况下,叉积沿最后一个轴计算。这是一个批处理的三维叉积,对输入的行进行操作

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

指定 axis=0 将其变为批处理的二维叉积,对输入的列进行操作

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

等效地,我们可以独立地指定输入 ab 以及输出 c 的轴

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)