jax.devices#

jax.devices(backend=None)[源代码]#

返回给定后端的所有设备的列表。

每个设备都由 Device 的子类表示(例如,CpuDeviceGpuDevice)。返回列表的长度等于 device_count(backend)。可以通过比较 Device.process_indexjax.process_index() 返回的值来识别本地设备。

如果 backendNone,则返回默认后端的所有设备。默认后端通常是 'gpu''tpu' (如果可用),否则为 'cpu'

参数:

backend (str | xla_client.Client | None | None) – 这是一个实验性功能,API 可能发生变化。 可选参数,一个表示 xla 后端的字符串: 'cpu', 'gpu', 或 'tpu'

返回:

Device 子类的列表。

返回类型:

list[xla_client.Device]