jax.numpy.take_along_axis

jax.numpy.take_along_axis#

jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[source]#

从数组中获取元素。

JAX 实现 numpy.take_along_axis(),用 jax.lax.gather() 实现。JAX 的行为在超出边界索引的情况下与 NumPy 不同;请参阅下面的 mode 参数。

参数:
  • a – 要从中获取值的数组。

  • indices (ArrayLike) – 整数索引数组。如果 axisNone,则必须是一维的。如果 axis 不为 None,则必须有 a.ndim == indices.ndim,并且 a 必须与 indices 在除 axis 之外的维度上广播兼容。

  • (int | None) – 要获取值的轴。如果未指定,则数组将在应用索引之前被扁平化。

  • 模式 (str | lax.GatherScatterMode | None) – 越界索引模式,要么是 "fill" 要么是 "clip"。默认的 mode="fill" 会为越界索引返回无效值(例如 NaN)。有关 mode 选项的更多讨论,请参见 jax.numpy.ndarray.at

  • arr (ArrayLike)

  • 填充值 (StaticScalar | None)

返回:

a 中提取的值数组。

返回类型:

数组

另请参见

示例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)

越界索引将用无效值填充。对于浮点型输入,这是 NaN

>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)

take_along_axis 有助于从多维 argsorts 和 arg 缩减中提取值。例如,这里我们计算 argsort() 沿轴的索引,并使用 take_along_axis 来构造排序后的数组

>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)

类似地,我们可以使用 argmin() 以及 keepdims=True 并使用 take_along_axis 来提取最小值

>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)