jax.numpy.mask_indices#
- jax.numpy.mask_indices(*args, **kwargs)[source]#
给定掩码函数,返回访问 (n, n) 数组的索引。
numpy.mask_indices()
的 LAX 后端实现。原始文档字符串如下。
假设 mask_func 是一个函数,对于大小为
(n, n)
的方形数组 a,可能带有一个偏移参数 k,当调用mask_func(a, k)
时,返回一个新的数组,其中某些位置为零(像 triu 或 tril 这样的函数正是这样做的)。然后,此函数返回非零值所在位置的索引。- 参数:
n (int) – 返回的索引将有效地访问形状为 (n, n) 的数组。
mask_func (callable) – 一个函数,其调用签名类似于 triu、tril。也就是说,
mask_func(x, k)
返回一个布尔数组,形状与 x 相同。k 是函数的可选参数。k (scalar) – 一个可选参数,它被传递给 mask_func。像 triu、tril 这样的函数接受第二个参数,该参数被解释为偏移量。
- 返回值:
indices – 对应于
mask_func(np.ones((n, n)), k)
为 True 的位置的 n 个索引数组。- 返回类型:
tuple of arrays.