jax.experimental.pallas.swap

内容

jax.experimental.pallas.swap#

jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')[source]#

交换给定索引处的数值并返回旧值。

请参阅 load() 了解参数的含义。

返回值:

交换前 ref 中存储的值。

返回类型:

jax.Array