jax.experimental.pallas.num_programs

jax.experimental.pallas.num_programs#

jax.experimental.pallas.num_programs(axis)[source]#

返回沿给定轴的网格大小。

参数:

axis (int)

返回类型:

int | jax.Array