jax.numpy.tril#
- jax.numpy.tril(m, k=0)[source]#
返回数组的下三角。
numpy.tril()
的 JAX 实现- 参数:
m (ArrayLike) – 输入数组。必须有
m.ndim >= 2
。k (int) – k:可选,int,默认为 0。指定数组中元素被设置为零的副对角线上方。
k=0
指的是主对角线,k<0
指的是主对角线以下的副对角线,而k>0
指的是主对角线以上的副对角线。
- 返回值:
一个与输入数组形状相同的数组,包含给定数组的上三角,其中由
k
指定的副对角线下方的元素被设置为零。- 返回类型:
另请参阅
jax.numpy.triu()
:返回数组的上三角。jax.numpy.tri()
:返回一个数组,其中对角线及其以下为 1,其他位置为 0。
示例
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.tril(x) Array([[ 1, 0, 0, 0], [ 5, 6, 0, 0], [ 9, 10, 11, 0]], dtype=int32) >>> jnp.tril(x, k=1) Array([[ 1, 2, 0, 0], [ 5, 6, 7, 0], [ 9, 10, 11, 12]], dtype=int32) >>> jnp.tril(x, k=-1) Array([[ 0, 0, 0, 0], [ 5, 0, 0, 0], [ 9, 10, 0, 0]], dtype=int32)
当
m.ndim > 2
时,jnp.tril
会对尾随轴进行批处理操作。>>> x1 = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.tril(x1) Array([[[1, 0], [3, 4]], [[5, 0], [7, 8]]], dtype=int32)