jax.numpy.block#

jax.numpy.block(arrays)[源代码]#

从块列表创建一个数组。

JAX 对 numpy.block() 的实现。

参数:

arrays (类数组 | 列表[类数组]) – 一个数组,或者一个嵌套的数组列表,它们将被连接在一起形成最终数组。

返回:

一个由输入构建的单个数组。

返回类型:

数组

另请参阅

示例

考虑这些块

>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)

将单个数组传递给 block() 返回该数组

>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

传递一个简单的数组列表会沿着最后一个轴连接它们

>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)

传递一个双层嵌套的数组列表会沿着最后一个轴连接内部列表,并沿着倒数第二个轴连接外部列表

>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)

请注意,块不必在所有维度上对齐,尽管沿着连接轴的大小必须匹配。例如,这是有效的,因为在内部的水平连接之后,生成的块具有用于外部垂直连接的有效形状。

>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)

另请注意,此逻辑推广到 3 个或更多维度中的块。这是一个 3 维的块状数组

>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)