jax.numpy.take_along_axis#

jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[源代码]#

从数组中提取元素。

JAX 对 numpy.take_along_axis() 的实现,使用 jax.lax.gather() 实现。在越界索引的情况下,JAX 的行为与 NumPy 不同;请参阅下面的 mode 参数。

参数:
  • a – 要从中提取值的数组。

  • indices (ArrayLike) – 整数索引的数组。如果 axisNone,则必须是一维的。如果 axis 不为 None,则必须有 a.ndim == indices.ndim,并且 a 必须与 indices 在除 axis 之外的维度上进行广播兼容。

  • axis (int | None) – 沿其提取值的轴。如果未指定,则在应用索引之前将展平数组。

  • mode (str | lax.GatherScatterMode | None) – 越界索引模式,可以是 "fill""clip"。默认的 mode="fill" 对于越界索引返回无效值(例如 NaN)。有关 mode 选项的更多讨论,请参阅 jax.numpy.ndarray.at

  • arr (ArrayLike)

  • fill_value (StaticScalar | None)

返回:

a 中提取的值的数组。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)

越界索引填充无效值。对于浮点输入,这是 NaN

>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)

take_along_axis 有助于从多维 argsort 和参数规约中提取值。在这里,我们计算沿轴的 argsort() 索引,并使用 take_along_axis 构建排序后的数组

>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)

类似地,我们可以使用 argmin()keepdims=True,并使用 take_along_axis 提取最小值

>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)