jax.experimental.pallas 模块#

Pallas 的模块,Pallas 是一个用于自定义内核的 JAX 扩展。

请参阅 Pallas 文档:https://jax.ac.cn/en/latest/pallas.html

后端#

#

BlockSpec([block_shape, index_map, ...])

指定应如何为内核的每次调用切分数组。

GridSpec([grid, in_specs, out_specs, ...])

jax.experimental.pallas.pallas_call() 编码网格参数。

Slice(start, size[, stride])

一个带有起始索引和大小的切片。

MemoryRef(shape, dtype, memory_space)

类似于 jax.ShapeDtypeStruct,但带有内存空间。

函数#

pallas_call(kernel, out_shape, *[, ...])

在某些输入上调用 Pallas 内核。

program_id(axis)

返回沿网格给定轴的内核执行位置。

num_programs(axis)

返回沿给定轴的网格大小。

load(x_ref_or_view, idx, *[, mask, other, ...])

返回从给定索引加载的数组。

store(x_ref_or_view, idx, val, *[, mask, ...])

在给定索引处存储一个值。

swap(x_ref_or_view, idx, val, *[, mask, ...])

交换给定索引处的值并返回旧值。

atomic_and(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] &= val

atomic_add(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] += val

atomic_cas(ref, cmp, val)

执行对 ref 中值的原子比较和交换操作。

atomic_max(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val)

atomic_min(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val)

atomic_or(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] |= val

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

将给定值与给定索引处的值进行原子交换。

atomic_xor(x_ref_or_view, idx, val, *[, mask])

原子地计算 x_ref_or_view[idx] ^= val

broadcast_to(a, shape)

debug_print(fmt, *args)

从 Pallas 内核内部打印值。

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

max_contiguous(x, values)

multiple_of(x, values)

run_scoped(f, *types, **kw_types)

调用带有已分配引用的函数并返回结果。

when(condition)