jax.Array.at

内容

jax.Array.at#

abstract property Array.at[source]#

索引更新功能的辅助属性。

The at property provides a functionally pure equivalent of in-place array modifications.

特别是

替代语法

等效的原位表达式

x = x.at[idx].set(y)

x[idx] = y

x = x.at[idx].add(y)

x[idx] += y

x = x.at[idx].multiply(y)

x[idx] *= y

x = x.at[idx].divide(y)

x[idx] /= y

x = x.at[idx].power(y)

x[idx] **= y

x = x.at[idx].min(y)

x[idx] = minimum(x[idx], y)

x = x.at[idx].max(y)

x[idx] = maximum(x[idx], y)

x = x.at[idx].apply(ufunc)

ufunc.at(x, idx)

x = x.at[idx].get()

x = x[idx]

x.at 中的表达式都不会修改原始的 x;而是返回 x 的修改副本。但是,在 jit() 编译的函数内部,像 x = x.at[idx].set(y) 这样的表达式保证会被就地应用。

与 NumPy 中的原地操作(如 x[idx] += y)不同,如果多个索引引用同一位置,则所有更新都将应用(NumPy 只会应用最后一次更新,而不是应用所有更新)。冲突更新应用的顺序是实现定义的,并且可能是不可确定的(例如,由于某些硬件平台上的并发)。

默认情况下,JAX 假设所有索引都在范围内。可以通过 mode 参数指定其他边界外索引语义(见下文)。

参数::
  • mode (str) –

    指定边界外索引模式。选项有

    • "promise_in_bounds": (默认)用户承诺索引在范围内。不会执行任何其他检查。在实践中,这意味着 get() 中的边界外索引将被剪切,而 set()add() 等中的边界外索引将被丢弃。

    • "clip": 将边界外索引夹紧到有效范围内。

    • "drop": 忽略边界外索引。

    • "fill": "drop" 的别名。对于 get(),可选的 fill_value 参数指定将返回的值。

      有关更多详细信息,请参见 jax.lax.GatherScatterMode

  • indices_are_sorted (bool) – 如果为 True,实现将假设传递给 at[] 的索引按升序排序,这可能会在某些后端导致更有效的执行。

  • unique_indices (bool) – 如果为 True,实现将假设传递给 at[] 的索引是唯一的,这可能会在某些后端导致更有效的执行。

  • fill_value (Any) – 仅适用于 get() 方法:当 mode'fill' 时,用于边界外切片返回的填充值。否则被忽略。对于非精确类型默认为 NaN,对于带符号类型默认为最大负值,对于无符号类型默认为最大正值,对于布尔类型默认为 True

示例

>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)
>>> x.at[10].add(10)  # out-of-bounds indices are ignored
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)