jax.numpy.mask_indices

jax.numpy.mask_indices#

jax.numpy.mask_indices(*args, **kwargs)[source]#

给定掩码函数,返回访问 (n, n) 数组的索引。

numpy.mask_indices() 的 LAX 后端实现。

原始文档字符串如下。

假设 mask_func 是一个函数,对于大小为 (n, n) 的方形数组 a,可能带有一个偏移参数 k,当调用 mask_func(a, k) 时,返回一个新的数组,其中某些位置为零(像 triutril 这样的函数正是这样做的)。然后,此函数返回非零值所在位置的索引。

参数:
  • n (int) – 返回的索引将有效地访问形状为 (n, n) 的数组。

  • mask_func (callable) – 一个函数,其调用签名类似于 triutril。也就是说,mask_func(x, k) 返回一个布尔数组,形状与 x 相同。k 是函数的可选参数。

  • k (scalar) – 一个可选参数,它被传递给 mask_func。像 triutril 这样的函数接受第二个参数,该参数被解释为偏移量。

返回值:

indices – 对应于 mask_func(np.ones((n, n)), k) 为 True 的位置的 n 个索引数组。

返回类型:

tuple of arrays.