jax.numpy.choose#

jax.numpy.choose(a, choices, out=None, mode='raise')[源代码]#

通过堆叠选择数组的切片来构造一个数组。

JAX 实现的 numpy.choose()

此函数的语义可能令人困惑,但在最简单的情况下,当 a 是一个一维数组, choices 是一个二维数组,并且 a 的所有条目都在范围内(即 0 <= a_i < len(choices))时,该函数等效于以下内容

def choose(a, choices):
  return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

在更一般的情况下, a 可能具有任意数量的维度,而 choices 可能是广播兼容的数组的任意序列。在这种情况下,同样对于边界内的索引,逻辑等效于

def choose(a, choices):
  a, *choices = jnp.broadcast_arrays(a, *choices)
  choices = jnp.array(choices)
  return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])

唯一的额外复杂性来自 mode 参数,它控制 a 中超出边界的索引的行为,如下所述。

参数:
  • a (ArrayLike) – 一个整数索引的 N 维数组。

  • choices (Array | np.ndarray | Sequence[ArrayLike]) – 一个数组或数组序列。序列中的所有数组必须与 a 相互广播兼容。

  • out (None | None) – JAX 未使用

  • mode (str) – 指定越界索引模式; 'raise'(默认), 'wrap''clip' 之一。请注意, 'raise' 的默认模式与 JAX 转换不兼容。

返回:

一个数组,其中包含来自 choices 的堆叠切片,其索引由 a 指定。结果的形状为 broadcast_shapes(a.shape, *(c.shape for c in choices))

返回类型:

Array

另请参阅

示例

这是具有二维选择数组的一维索引数组的最简单情况,在这种情况下,它从每列中选择索引值

>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)

mode 参数指定如何处理越界索引;选项是 wrapclip

>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)

在更一般的情况下, choices 可以是具有任何广播兼容形状的类数组对象的序列。

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)