jax.numpy.triu

内容

jax.numpy.triu#

jax.numpy.triu(m, k=0)[source]#

返回数组的上三角。

JAX 实现 numpy.triu()

参数:
  • m (ArrayLike) – 输入数组。必须有 m.ndim >= 2

  • k (int) – 可选,int,默认值为 0。指定低于该子对角线的数组元素将被设置为零。 k=0 指主对角线, k<0 指主对角线以下的子对角线,而 k>0 指主对角线以上的子对角线。

返回值:

返回一个与输入形状相同的数组,其中给定数组的下三角矩阵元素(位于由 k 指定的次对角线以上)被设置为零。

返回类型:

数组

参见

示例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9],
...                [10, 11, 12]])
>>> jnp.triu(x)
Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=1)
Array([[0, 2, 3],
       [0, 0, 6],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=-1)
Array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]], dtype=int32)

m.ndim > 2 时,jnp.triu 对尾随轴进行批处理。

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.triu(x1)
Array([[[1, 2],
        [0, 4]],

       [[5, 6],
        [0, 8]]], dtype=int32)