jax.numpy.broadcast_shapes#

jax.numpy.broadcast_shapes(*shapes)[源代码]#

将输入形状广播到公共输出形状。

numpy.broadcast_shapes() 的 JAX 实现。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播 中阅读更多相关信息。

参数:

shapes – 0 个或多个形状,指定为整数序列

返回:

广播后的形状,以整数元组的形式返回。

另请参阅

示例

一些兼容的形状

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

不兼容的形状

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]