jax.numpy.block

内容

jax.numpy.block#

jax.numpy.block(arrays)[source]#

从嵌套块列表中组装一个 nd 数组。

LAX 后端实现 numpy.block()

原始文档字符串如下。

最内层列表中的块沿最后一个维度 (-1) 连接(参见 concatenate),然后这些块沿倒数第二个维度 (-2) 连接,依此类推,直到到达最外层列表。

块可以是任何维度,但不会使用正常的规则进行广播。相反,将插入大小为 1 的前导轴,以使所有块的 block.ndim 相同。这主要用于处理标量,这意味着像 np.block([v, 1]) 这样的代码是有效的,其中 v.ndim == 1

当嵌套列表深度为两层时,这允许从其组件构建块矩阵。

在版本 1.13.0 中添加。

参数:

arrays (嵌套列表array_like标量 (但不是元组)) –

如果传递单个 ndarray 或标量(深度为 0 的嵌套列表),则会返回未修改的副本(不会复制)。

元素形状必须沿适当的轴匹配(无需广播),但根据需要会将前导 1 预先添加到形状以使维度匹配。

返回值:

block_array – 从给定块组装的数组。

输出的维度等于以下两者中最大的一个:

  • 所有输入的维度

  • 输入列表嵌套的深度

返回类型:

ndarray