jax.numpy.cross#
- jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[源代码]#
计算两个数组的(批处理)叉积。
numpy.cross()
的 JAX 实现。这将计算二维或三维的叉积,
\[c = a \times b\]在三维空间中,
c
是一个长度为 3 的数组。在二维空间中,c
是一个标量。- 参数:
a – N 维数组。
a.shape[axisa]
表示叉积的维度,必须为 2 或 3。b – N 维数组。必须有
b.shape[axisb] == a.shape[axisb]
,并且a
和b
的其他维度必须具有广播兼容性。axisa (int) – 指定
a
中计算叉积的轴。axisb (int) – 指定
b
中计算叉积的轴。axisc (int) – 指定
c
中存储叉积结果的轴。axis (int | None) – 如果指定,则使用单个值覆盖
axisa
、axisb
和axisc
。
- 返回:
数组
c
,包含a
和b
沿指定轴的(批处理)叉积。
另请参阅
jax.numpy.linalg.cross()
:一个用于计算三维向量叉积的数组 API 兼容函数。
示例
二维叉积返回一个标量
>>> a = jnp.array([1, 2]) >>> b = jnp.array([3, 4]) >>> jnp.cross(a, b) Array(-2, dtype=int32)
三维叉积返回一个长度为 3 的向量
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.cross(a, b) Array([-3, 6, -3], dtype=int32)
对于多维输入,默认情况下,叉积沿最后一个轴计算。这是一个批处理的三维叉积,对输入的行进行操作
>>> a = jnp.array([[1, 2, 3], ... [3, 4, 3]]) >>> b = jnp.array([[2, 3, 2], ... [4, 5, 6]]) >>> jnp.cross(a, b) Array([[-5, 4, -1], [ 9, -6, -1]], dtype=int32)
指定 axis=0 将其变为批处理的二维叉积,对输入的列进行操作
>>> jnp.cross(a, b, axis=0) Array([-2, -2, 12], dtype=int32)
等效地,我们可以独立地指定输入
a
和b
以及输出c
的轴>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0) Array([-2, -2, 12], dtype=int32)