jax.experimental.pallas 模块

内容

jax.experimental.pallas 模块#

用于 Pallas 的模块,Pallas 是一个用于自定义内核的 JAX 扩展。

请参阅 Pallas 文档:https://jax.ac.cn/en/latest/pallas.html

#

BlockSpec([block_shape, index_map, ...])

指定数组应如何为内核的每次调用切片。

GridSpec([grid, in_specs, out_specs])

jax.experimental.pallas.pallas_call() 编码网格参数。

Slice(start, size[, stride])

具有起始索引和大小的切片。

函数#

pallas_call(kernel, out_shape, *[, ...])

在一些输入上调用 Pallas 内核。

program_id(axis)

返回内核执行在网格给定轴上的位置。

num_programs(axis)

返回网格在给定轴上的大小。

load(x_ref_or_view, idx, *[, mask, other, ...])

返回从给定索引加载的数组。

store(x_ref_or_view, idx, val, *[, mask, ...])

在给定索引处存储一个值。

swap(x_ref_or_view, idx, val, *[, mask, ...])

交换给定索引处的值并返回旧值。

atomic_and(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] &= val

atomic_add(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] += val

atomic_cas(ref, cmp, val)

执行ref中值的原子比较和交换操作,使用给定的值。

atomic_max(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val)

atomic_min(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val)

atomic_or(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] |= val

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

原子地交换给定值与给定索引处的值。

debug_print(fmt, *args)

从 Pallas 内核内部打印标量值。