jax.experimental.pallas.num_programs# jax.experimental.pallas.num_programs(axis)[source]# 返回沿给定轴的网格大小。 参数: axis (int) 返回类型: int | jax.Array