jax.experimental.pallas.load

内容

jax.experimental.pallas.load#

jax.experimental.pallas.load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy=None, volatile=False)[source]#

从给定索引加载的数组。

如果既没有指定 mask 也没有指定 other,则此函数与 JAX 中的 x_ref_or_view[idx] 的语义相同。

参数:
  • x_ref_or_view – 要从中加载的引用。

  • idx – 要使用的索引器。

  • mask – 一个可选的布尔掩码,指定要加载的索引。如果 mask 为 False 且未给出 other,则无法对结果数组中的值做出任何假设。

  • other – 当 mask 为 False 时要使用的可选值。

  • cache_modifier – 待文档化。

  • eviction_policy – 待文档化。

  • volatile – 待文档化。

返回类型:

jax.Array