jax.numpy.broadcast_arrays#

jax.numpy.broadcast_arrays(*args)[源代码]#

将数组广播到公共形状。

numpy.broadcast_arrays() 的 JAX 实现。JAX 使用 NumPy 风格的广播规则,你可以在 NumPy 广播 中阅读更多相关内容。

参数:

args (ArrayLike) – 零个或多个要广播的类数组对象。

返回:

包含输入广播副本的数组列表。

返回类型:

list[Array]

另请参阅

示例

>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)