jax.experimental.pallas.swap
jax.experimental.pallas.swap
-
jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')[source]
交换给定索引处的数值并返回旧值。
请参阅 load()
了解参数的含义。
- 返回值:
交换前 ref 中存储的值。
- 返回类型:
jax.Array