jax.numpy.mask_indices#

jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)[源代码]#

返回一个 (n, n) 数组的掩码的索引。

参数:
  • n (int) – 静态整数数组维度。

  • mask_func (Callable[[ArrayLike, int], Array]) – 一个函数,它接受一个形状为 (n, n) 的数组和一个可选的偏移量 k,并返回一个形状为 (n, n) 的掩码。具有此签名的函数示例包括 triu()tril()

  • k (int) – 传递给 mask_func 的标量值。

  • size (int | None | None) – 可选参数,指定输出数组的静态大小。这在从掩码生成索引时传递给 nonzero()

返回:

一个索引元组,其中 mask_func 非零。

返回类型:

tuple[Array, Array]

另请参阅

示例

对内置掩码函数调用 mask_indices

>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

对自定义掩码函数调用 mask_indices

>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))