jax.experimental.pallas.program_id# jax.experimental.pallas.program_id(axis)[源代码]# 返回沿网格给定轴的内核执行位置。 例如,在内核执行中具有与网格坐标 (1, 2) 对应的 2D grid,program_id(axis=0) 返回 1,program_id(axis=1) 返回 2。 返回的值是形状为 () 和 dtype 为 int32 的数组。 参数: axis (int) – 网格的轴,沿该轴计算程序。 返回类型: jax.Array