jax.numpy.broadcast_shapes#
- jax.numpy.broadcast_shapes(*shapes: Sequence[int]) tuple[int, ...] [source]#
- jax.numpy.broadcast_shapes(*shapes: Sequence[int | core.Tracer]) tuple[int | core.Tracer, ...]
将输入形状广播到一个共同的输出形状。
JAX 实现
numpy.broadcast_shapes()
。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播 中了解更多信息。- 参数:
shapes – 0 个或多个形状,指定为整数序列
- 返回值:
广播后的形状,表示为整数元组。
另请参阅
jax.numpy.broadcast_arrays()
: 将数组广播到一个共同的形状。jax.numpy.broadcast_to()
: 将数组广播到指定的形状。
示例
一些兼容的形状
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
不兼容的形状
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]