jax.numpy.choose

内容

jax.numpy.choose#

jax.numpy.choose(a, choices, out=None, mode='raise')[source]#

从索引数组和一组要从中选择的数组构造一个数组。

LAX 后端实现 numpy.choose().

原始文档字符串如下。

首先,如果感到困惑或不确定,一定要查看示例——就其普遍性而言,此函数并不像从以下代码描述(ndi = numpy.lib.index_tricks)中看到的那样简单。

np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)]).

但这忽略了一些细微差别。这是一个完全通用的总结

给定一个整数“索引”数组(a)和一个 n 个数组的序列(choices),a 和每个选择数组首先根据需要广播到具有公共形状的数组;称这些数组为 BaBchoices[i], i = 0,…,n-1,我们有必要 Ba.shape == Bchoices[i].shape 对于每个 i。然后,创建一个新的形状为 Ba.shape 的数组,如下所示

  • 如果 mode='raise'(默认值),那么首先,a(以及 Ba)的每个元素都必须在范围 [0, n-1] 内;现在,假设 i(在该范围内)是 Ba(j0, j1, ..., jm) 位置的值 - 那么新数组中相同位置的值就是 Bchoices[i] 中相同位置的值;

  • 如果 mode='wrap',则 a(以及 Ba)中的值可以是任何(有符号)整数;使用模运算将范围 [0, n-1] 外的整数映射回该范围;然后像上面一样构造新数组;

  • 如果 mode='clip',则 a(以及 Ba)中的值可以是任何(有符号)整数;负整数映射到 0;大于 n-1 的值映射到 n-1;然后像上面一样构造新数组。

参数:
  • a (int 数组) – 此数组必须包含 [0, n-1] 中的整数,其中 n 是选择的数量,除非 mode=wrapmode=clip,在这种情况下,任何整数都是允许的。

  • choices (数组序列) – 选择数组。 a 和所有选择都必须广播到相同的形状。如果 choices 本身是一个数组(不推荐),则其最外层维度(即对应于 choices.shape[0] 的维度)被视为定义“序列”。

  • mode ({'raise' (默认值), 'wrap', 'clip'}, 可选) –

    指定如何处理 [0, n-1] 外的索引

    • ’raise’ : 抛出异常

    • ’wrap’ : 值变为值模 n

    • ’clip’ : 值 < 0 映射到 0,值 > n-1 映射到 n-1

  • out (None | None)

返回值:

merged_array – 合并的结果。

返回类型:

数组