jax.numpy.put_along_axis#
- jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[源代码]#
通过匹配 1d 索引和数据切片,将值放入目标数组。
numpy.put_along_axis()
的 JAX 实现。numpy.put_along_axis()
的语义是就地修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入的修改后的副本,并添加了inplace
参数,用户必须将其设置为 False`,以提醒此 API 差异。- 参数:
arr (ArrayLike) – 将放入值的数组。
indices (ArrayLike) – 用于放置值的索引数组。
values (ArrayLike) – 要放入数组的值的数组。
axis ( int | None) – 用于放置值的轴。如果未指定,则在应用索引之前,数组将被展平。
inplace ( bool ) – 必须设置为 False,以表明输入不会被原地修改,而是返回一个修改后的副本。
mode ( str | None) – 超出边界的索引模式。有关
mode
选项的更多讨论,请参阅jax.numpy.ndarray.at
。
- 返回:
一个
a
的副本,其中指定的条目已更新。- 返回类型:
另请参阅
jax.numpy.put()
: 将元素放入给定索引的数组中。jax.numpy.place()
: 通过布尔掩码将元素放入数组中。jax.numpy.ndarray.at()
: 使用 NumPy 样式的索引更新数组。jax.numpy.take()
: 从给定索引的数组中提取值。jax.numpy.take_along_axis()
: 沿轴从数组中提取值。
示例
>>> from jax import numpy as jnp >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) >>> i = jnp.argmax(a, axis=1, keepdims=True) >>> print(i) [[1] [0]] >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) >>> print(b) [[10 99 20] [99 40 50]]