jax.numpy.broadcast_shapes

jax.numpy.broadcast_shapes#

jax.numpy.broadcast_shapes(*shapes: Sequence[int]) tuple[int, ...][source]#
jax.numpy.broadcast_shapes(*shapes: Sequence[int | core.Tracer]) tuple[int | core.Tracer, ...]

将输入形状广播到一个共同的输出形状。

JAX 实现 numpy.broadcast_shapes()。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播 中了解更多信息。

参数:

shapes – 0 个或多个形状,指定为整数序列

返回值:

广播后的形状,表示为整数元组。

另请参阅

示例

一些兼容的形状

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

不兼容的形状

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]