jax.lax.axis_index

内容

jax.lax.axis_index#

jax.lax.axis_index(axis_name)[source]#

返回映射轴 axis_name 上的索引。

参数:

axis_name – 用于命名映射轴的可散列 Python 对象。

返回:

表示索引的整数。

例如,有 8 个 XLA 设备可用时

>>> from functools import partial
>>> @partial(jax.pmap, axis_name='i')
... def f(_):
...   return lax.axis_index('i')
...
>>> f(np.zeros(4))
Array([0, 1, 2, 3], dtype=int32)
>>> f(np.zeros(8))
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> @partial(jax.pmap, axis_name='i')
... @partial(jax.pmap, axis_name='j')
... def f(_):
...   return lax.axis_index('i'), lax.axis_index('j')
...
>>> x, y = f(np.zeros((4, 2)))
>>> print(x)
[[0 0]
[1 1]
[2 2]
[3 3]]
>>> print(y)
[[0 1]
[0 1]
[0 1]
[0 1]]