jax.numpy.block#
- jax.numpy.block(arrays)[源代码]#
从块列表创建一个数组。
JAX 对
numpy.block()
的实现。示例
考虑这些块
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
将单个数组传递给
block()
返回该数组>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
传递一个简单的数组列表会沿着最后一个轴连接它们
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
传递一个双层嵌套的数组列表会沿着最后一个轴连接内部列表,并沿着倒数第二个轴连接外部列表
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
请注意,块不必在所有维度上对齐,尽管沿着连接轴的大小必须匹配。例如,这是有效的,因为在内部的水平连接之后,生成的块具有用于外部垂直连接的有效形状。
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
另请注意,此逻辑推广到 3 个或更多维度中的块。这是一个 3 维的块状数组
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)