jax.scipy.linalg.block_diag#
- jax.scipy.linalg.block_diag(*arrs)[source]#
从输入数组创建块对角矩阵。
JAX 实现的
scipy.linalg.block_diag()
.- 参数:
*arrs (ArrayLike) – 最多两维的数组
- 返回值:
通过将输入数组沿对角线放置而构造的 2D 块对角数组。
- 返回类型:
示例
>>> A = jnp.ones((1, 1)) >>> B = jnp.ones((2, 2)) >>> C = jnp.ones((3, 3)) >>> jax.scipy.linalg.block_diag(A, B, C) Array([[1., 0., 0., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.]], dtype=float32)