jax.numpy.block#
- jax.numpy.block(arrays)[源代码]#
从块列表创建数组。
JAX 对
numpy.block()
的实现。- 参数:
arrays (ArrayLike | list[ArrayLike]) – 一个数组,或数组的嵌套列表,它们将被连接在一起以形成最终的数组。
- 返回:
由输入构建的单个数组。
- 返回类型:
示例
考虑这些块
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
将单个数组传递给
block()
返回该数组>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
传递一个简单的数组列表会沿着最后一个轴连接它们
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
传递一个双层嵌套的数组列表会沿着最后一个轴连接内部列表,并沿着倒数第二个轴连接外部列表
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
请注意,块不需要在所有维度上对齐,但沿着连接轴的大小必须匹配。例如,这是有效的,因为在内部的水平连接之后,生成的块具有用于外部垂直连接的有效形状。
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
还要注意,这个逻辑可以推广到 3 个或更多维度的块。这是一个 3 维的分块数组
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)