jax.numpy.fill_diagonal

jax.numpy.fill_diagonal#

jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)[source]#

返回一个数组的副本,其中对角线被覆盖。

JAX 实现 numpy.fill_diagonal()

numpy.fill_diagonal() 的语义是就地修改数组,这对于 JAX 的不可变数组来说是不可能的。JAX 版本返回输入的修改副本,并添加了 inplace 参数,用户必须将其设置为 False` 以提醒此 API 的差异。

参数:
  • a (类数组) – 输入数组。必须满足 a.ndim >= 2。如果 a.ndim >= 3,则所有维度的大小必须相同。

  • val (类数组) – 用来填充对角线的标量或数组。如果为数组,它将被展平并重复以填充对角线条目。

  • inplace (布尔值) – 必须设置为 False 以指示输入不会被就地修改,而是返回一个修改后的副本。

  • wrap (布尔值)

返回值:

一个 a 的副本,其对角线设置为 val

返回类型:

数组

示例

>>> x = jnp.zeros((3, 3), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False)
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

numpy.fill_diagonal() 不同,输入 x 不会被修改。

如果对角线值条目过多,它将被截断

>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False)
Array([[100,   0,   0],
       [  0, 101,   0],
       [  0,   0, 102]], dtype=int32)

如果对角线条目过少,它将被重复

>>> x = jnp.zeros((4, 4), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False)
Array([[3, 0, 0, 0],
       [0, 4, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)

对于非方阵,填充前导方阵切片的对角线

>>> x = jnp.zeros((3, 5), dtype=int)
>>> jnp.fill_diagonal(x, 1, inplace=False)
Array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0]], dtype=int32)

对于方 N 维数组,填充 N 维对角线

>>> y = jnp.zeros((2, 2, 2))
>>> jnp.fill_diagonal(y, 1, inplace=False)
Array([[[1., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 1.]]], dtype=float32)