jax.numpy.broadcast_shapes#
- jax.numpy.broadcast_shapes(*shapes)[源代码]#
将输入形状广播到公共输出形状。
numpy.broadcast_shapes()
的 JAX 实现。JAX 使用 NumPy 风格的广播规则,您可以在 NumPy 广播 中阅读更多相关信息。- 参数:
shapes – 0 个或多个形状,指定为整数序列
- 返回:
广播后的形状,以整数元组的形式返回。
另请参阅
jax.numpy.broadcast_arrays()
:将数组广播到公共形状。jax.numpy.broadcast_to()
:将数组广播到指定的形状。
示例
一些兼容的形状
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
不兼容的形状
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]