jax.scipy.linalg.block_diag

内容

jax.scipy.linalg.block_diag#

jax.scipy.linalg.block_diag(*arrs)[source]#

从输入数组创建块对角矩阵。

JAX 实现的 scipy.linalg.block_diag().

参数:

*arrs (ArrayLike) – 最多两维的数组

返回值:

通过将输入数组沿对角线放置而构造的 2D 块对角数组。

返回类型:

Array

示例

>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)