jax.experimental.pallas.program_id

内容

jax.experimental.pallas.program_id#

jax.experimental.pallas.program_id(axis)[source]#

返回内核执行在网格给定轴上的位置。

例如,在内核执行中对应于网格坐标 (1, 2) 的二维 grid 中,program_id(axis=0) 返回 1program_id(axis=1) 返回 2

参数:

axis (int) – 要计算程序的网格轴。

返回类型:

jax.Array