jax.numpy.diag

内容

jax.numpy.diag#

jax.numpy.diag(v, k=0)[source]#

返回指定的对角线或构造一个对角线数组。

JAX 实现 numpy.diag().

JAX 版本始终返回输入的副本,尽管如果这在 JIT 编译中使用,编译器可能会避免复制。

参数:
  • v (ArrayLike) – 输入数组。可以是 1 维数组以创建对角矩阵,也可以是 2 维数组以提取对角线。

  • k (int) – 可选,默认为 0。对角线偏移量。正值将对角线放置在主对角线上方,负值将对角线放置在主对角线下方。

返回:

如果 v 是一个 2 维数组,则包含对角线元素的 1 维数组。如果 v 是一个 1 维数组,则是一个 2 维数组,其中输入元素沿指定对角线放置。

返回类型:

数组

示例

从 1 维数组创建对角矩阵

>>> jnp.diag(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

指定对角线偏移量

>>> jnp.diag(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)

从 2 维数组提取对角线

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diag(x)
Array([1, 5, 9], dtype=int32)