jax.numpy.put_along_axis#

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[源代码]#

通过匹配 1d 索引和数据切片,将值放入目标数组。

numpy.put_along_axis() 的 JAX 实现。

numpy.put_along_axis() 的语义是就地修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入的修改后的副本,并添加了 inplace 参数,用户必须将其设置为 False`,以提醒此 API 差异。

参数:
  • arr (ArrayLike) – 将放入值的数组。

  • indices (ArrayLike) – 用于放置值的索引数组。

  • values (ArrayLike) – 要放入数组的值的数组。

  • axis ( int | None) – 用于放置值的轴。如果未指定,则在应用索引之前,数组将被展平。

  • inplace ( bool ) – 必须设置为 False,以表明输入不会被原地修改,而是返回一个修改后的副本。

  • mode ( str | None) – 超出边界的索引模式。有关 mode 选项的更多讨论,请参阅 jax.numpy.ndarray.at

返回:

一个 a 的副本,其中指定的条目已更新。

返回类型:

数组

另请参阅

示例

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]