jax.local_devices

jax.local_devices#

jax.local_devices(process_index=None, backend=None, host_id=None)[source]#

类似于 jax.devices(),但只返回特定进程的本地设备。

如果 process_indexNone,则返回当前进程的本地设备。

参数:
  • process_index (int | None) – 进程的整数索引。可以通过 len(jax.process_count()) 获取进程索引。

  • 后端 (str | xla_client.Client | None) – 这是一个实验性功能,API 可能会发生变化。可选,一个字符串,表示 XLA 后端:'cpu''gpu''tpu'

  • 主机 ID (int | None)

返回:

设备子类的列表。

返回类型:

list[xla_client.Device]