pallas_call (kernel, out_shape, *[, ...])
|
在某些输入上调用 Pallas 内核。 |
program_id (axis)
|
返回沿网格给定轴的内核执行位置。 |
num_programs (axis)
|
返回沿给定轴的网格大小。 |
load (x_ref_or_view, idx, *[, mask, other, ...])
|
返回从给定索引加载的数组。 |
store (x_ref_or_view, idx, val, *[, mask, ...])
|
在给定索引处存储一个值。 |
swap (x_ref_or_view, idx, val, *[, mask, ...])
|
交换给定索引处的值并返回旧值。 |
atomic_and (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] &= val 。 |
atomic_add (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] += val 。 |
atomic_cas (ref, cmp, val)
|
执行对 ref 中值的原子比较和交换操作。 |
atomic_max (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val) 。 |
atomic_min (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val) 。 |
atomic_or (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] |= val 。 |
atomic_xchg (x_ref_or_view, idx, val, *[, mask])
|
将给定值与给定索引处的值进行原子交换。 |
atomic_xor (x_ref_or_view, idx, val, *[, mask])
|
原子地计算 x_ref_or_view[idx] ^= val 。 |
broadcast_to (a, shape)
|
|
debug_print (fmt, *args)
|
从 Pallas 内核内部打印值。 |
dot (a, b[, trans_a, trans_b, allow_tf32, ...])
|
|
max_contiguous (x, values)
|
|
multiple_of (x, values)
|
|
run_scoped (f, *types, **kw_types)
|
调用带有已分配引用的函数并返回结果。 |
when (condition)
|
|