jax.numpy.broadcast_arrays

jax.numpy.broadcast_arrays#

jax.numpy.broadcast_arrays(*args)[source]#

将数组广播到共同的形状。

JAX 对 numpy.broadcast_arrays() 的实现。JAX 使用 NumPy 式的广播规则,您可以在 NumPy 广播 中了解更多信息。

参数:

args (ArrayLike) – 要广播的一个或多个类数组对象。

返回值:

包含广播后的输入副本的数组列表。

返回类型:

list[Array]

另请参阅

示例

>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)