jax.numpy.mask_indices#
- jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)[源代码]#
返回一个 (n, n) 数组的掩码的索引。
- 参数:
- 返回:
一个索引元组,其中
mask_func
非零。- 返回类型:
另请参阅
jax.numpy.triu_indices()
: 计算triu()
的mask_indices
。jax.numpy.tril_indices()
: 计算tril()
的mask_indices
。
示例
对内置掩码函数调用
mask_indices
>>> jnp.mask_indices(3, jnp.triu) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
对自定义掩码函数调用
mask_indices
>>> def mask_func(x, k=0): ... i = jnp.arange(x.shape[0])[:, None] ... j = jnp.arange(x.shape[1]) ... return (i + 1) % (j + 1 + k) == 0 >>> mask_func(jnp.ones((3, 3))) Array([[ True, False, False], [ True, True, False], [ True, False, True]], dtype=bool) >>> jnp.mask_indices(3, mask_func) (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))