jax.numpy.tri#

jax.numpy.tri(N, M=None, k=0, dtype=None)[源代码]#

返回一个数组,对角线及其下方的值为 1,其他位置为 0。

JAX 对 numpy.tri() 的实现

参数:
  • N (int) – int。返回数组的行维度。

  • M (int | None | None) – 可选,int。返回数组的列维度。如果未指定,则 M = N

  • k (int) – 可选,int,默认值为 0。指定数组中填充 1 的子对角线及其下方的位置。k=0 表示主对角线,k<0 表示主对角线以下的子对角线,k>0 表示主对角线以上的子对角线。

  • dtype (DTypeLike | None | None) – 可选,返回数组的数据类型。默认类型为 float。

返回值:

形状为 (N, M) 的数组,其中由 k 指定的子对角线以下(包括子对角线)的元素设置为 1,其他位置设置为 0。

返回类型:

数组

参见

示例

>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

M 不等于 N

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

k>0

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

k<0

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)