jax.local_devices#
- jax.local_devices(process_index=None, backend=None, host_id=None)[源代码]#
类似于
jax.devices()
,但只返回给定进程本地的设备。如果
process_index
为None
,则返回此进程的本地设备。- 参数:
process_index (int | None) – 进程的整数索引。进程索引可以通过
len(jax.process_count())
获取。backend (str | xla_client.Client | None) – 这是一个实验性功能,API 很可能会更改。 可选参数,表示 xla 后端的字符串:
'cpu'
,'gpu'
, 或'tpu'
。host_id (int | None)
- 返回值:
Device 子类的列表。
- 返回类型:
list[xla_client.Device]