jax.numpy.outer

内容

jax.numpy.outer#

jax.numpy.outer(a, b, out=None)[source]#

计算两个数组的外积。

JAX 实现 numpy.outer().

参数:
  • a (ArrayLike) – 第一个输入数组,如果它不是一维,则将被扁平化。

  • b (ArrayLike) – 第二个输入数组,如果它不是一维,则将被扁平化。

  • out (None) – JAX 不支持。

返回值:

输入 ab 的外积。返回的数组的形状将为 (a.size, b.size).

返回值类型:

数组

另请参阅

示例

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)