jax.experimental.pallas.program_id#

jax.experimental.pallas.program_id(axis)[源代码]#

返回网格给定轴上的内核执行位置。

例如,在内核执行中,对于与网格坐标 (1, 2) 对应的 2D 网格program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2

返回的值是一个形状为 () 且 dtype 为 int32 的数组。

参数:

axis (int) – 计算程序所沿的网格轴。

返回类型:

jax.Array