jax.experimental.pallas 模块#

Pallas 模块,一个用于自定义内核的 JAX 扩展。

请参阅 Pallas 文档:https://jax.ac.cn/en/latest/pallas.html

后端#

#

BlockSpec([block_shape, index_map, ...])

指定如何为内核的每次调用切分数组。

GridSpec([grid, in_specs, out_specs, ...])

jax.experimental.pallas.pallas_call() 编码网格参数。

Slice(开始, 大小[, 步长])

具有起始索引和大小的切片。

MemoryRef(形状, 数据类型, 内存空间)

类似于 jax.ShapeDtypeStruct,但带有内存空间。

函数#

pallas_call(内核, 输出形状, *[, ...])

在某些输入上调用 Pallas 内核。

program_id(轴)

返回沿网格给定轴的内核执行位置。

num_programs(轴)

返回沿给定轴的网格大小。

load(x_ref_or_view, 索引, *[, 掩码, 其他, ...])

返回从给定索引加载的数组。

store(x_ref_or_view, 索引, 值, *[, 掩码, ...])

在给定索引处存储一个值。

swap(x_ref_or_view, 索引, 值, *[, 掩码, ...])

交换给定索引处的值并返回旧值。

atomic_and(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] &= val

atomic_add(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] += val

atomic_cas(引用, 比较值, 新值)

对引用中的值与给定值执行原子比较并交换操作。

atomic_max(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val)

atomic_min(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val)

atomic_or(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] |= val

atomic_xchg(x_ref_or_view, 索引, 值, *[, 掩码])

原子地将给定值与给定索引处的值交换。

atomic_xor(x_ref_or_view, 索引, 值, *[, 掩码])

原子计算 x_ref_or_view[idx] ^= val

broadcast_to(a, 形状)

debug_print(格式, *参数)

从 Pallas 内核内部打印值。

dot(a, b[, 转置a, 转置b, 允许tf32, ...])

max_contiguous(x, 值)

multiple_of(x, 值)

run_scoped(f, *类型, **kw_类型)

使用分配的引用调用函数并返回结果。

when(条件)