jax.experimental.pallas
模块#
用于 Pallas 的模块,Pallas 是一个用于自定义内核的 JAX 扩展。
请参阅 Pallas 文档:https://jax.ac.cn/en/latest/pallas.html。
类#
|
指定数组应如何为内核的每次调用切片。 |
|
为 |
|
具有起始索引和大小的切片。 |
函数#
|
在一些输入上调用 Pallas 内核。 |
|
返回内核执行在网格给定轴上的位置。 |
|
返回网格在给定轴上的大小。 |
|
返回从给定索引加载的数组。 |
|
在给定索引处存储一个值。 |
|
交换给定索引处的值并返回旧值。 |
|
原子地计算 |
|
原子地计算 |
|
执行ref中值的原子比较和交换操作,使用给定的值。 |
|
原子地计算 |
|
原子地计算 |
|
原子地计算 |
|
原子地交换给定值与给定索引处的值。 |
|
从 Pallas 内核内部打印标量值。 |