jax.devices#
- jax.devices(backend=None)[源代码]#
返回给定后端的所有设备的列表。
每个设备都由
Device
的子类表示(例如,CpuDevice
,GpuDevice
)。返回列表的长度等于device_count(backend)
。可以通过比较Device.process_index
与jax.process_index()
返回的值来识别本地设备。如果
backend
为None
,则返回默认后端的所有设备。默认后端通常是'gpu'
或'tpu'
(如果可用),否则为'cpu'
。- 参数:
backend (str | xla_client.Client | None | None) – 这是一个实验性功能,API 可能发生变化。 可选参数,一个表示 xla 后端的字符串:
'cpu'
,'gpu'
, 或'tpu'
。- 返回:
Device 子类的列表。
- 返回类型:
list[xla_client.Device]