jax.numpy.choose#
- jax.numpy.choose(a, choices, out=None, mode='raise')[源代码]#
通过堆叠选择数组的切片来构造一个数组。
JAX 实现的
numpy.choose()
。此函数的语义可能令人困惑,但在最简单的情况下,当
a
是一个一维数组,choices
是一个二维数组,并且a
的所有条目都在范围内(即0 <= a_i < len(choices)
)时,该函数等效于以下内容def choose(a, choices): return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
在更一般的情况下,
a
可能具有任意数量的维度,而choices
可能是广播兼容的数组的任意序列。在这种情况下,同样对于边界内的索引,逻辑等效于def choose(a, choices): a, *choices = jnp.broadcast_arrays(a, *choices) choices = jnp.array(choices) return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])
唯一的额外复杂性来自
mode
参数,它控制a
中超出边界的索引的行为,如下所述。- 参数:
- 返回:
一个数组,其中包含来自
choices
的堆叠切片,其索引由a
指定。结果的形状为broadcast_shapes(a.shape, *(c.shape for c in choices))
。- 返回类型:
另请参阅
jax.lax.switch()
:根据索引在 N 个函数之间进行选择。
示例
这是具有二维选择数组的一维索引数组的最简单情况,在这种情况下,它从每列中选择索引值
>>> choices = jnp.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [ 9, 10, 11, 12]]) >>> a = jnp.array([2, 0, 1, 0]) >>> jnp.choose(a, choices) Array([9, 2, 7, 4], dtype=int32)
mode
参数指定如何处理越界索引;选项是wrap
或clip
>>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound >>> jnp.choose(a2, choices, mode='clip') Array([ 9, 2, 7, 12], dtype=int32) >>> jnp.choose(a2, choices, mode='wrap') Array([9, 2, 7, 8], dtype=int32)
在更一般的情况下,
choices
可以是具有任何广播兼容形状的类数组对象的序列。>>> choice_1 = jnp.array([1, 2, 3, 4]) >>> choice_2 = 99 >>> choice_3 = jnp.array([[10], ... [20], ... [30]]) >>> a = jnp.array([[0, 1, 2, 0], ... [1, 2, 0, 1], ... [2, 0, 1, 2]]) >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') Array([[ 1, 99, 10, 4], [99, 20, 3, 99], [30, 2, 99, 30]], dtype=int32)