jax.devices

内容

jax.devices#

jax.devices(backend=None)[source]#

返回给定后端的所有设备的列表。

每个设备都由 Device(例如 CpuDeviceGpuDevice)的子类表示。返回列表的长度等于 device_count(backend)。可以通过将 Device.process_indexjax.process_index() 返回的值进行比较来识别本地设备。

如果 backendNone,则返回默认后端的所有设备。默认后端通常是 'gpu''tpu'(如果可用),否则为 'cpu'

参数:

backend (str | xla_client.Client | None | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 xla 后端的字符串:'cpu''gpu''tpu'

返回:

Device 子类的列表。

返回类型:

list[xla_client.Device]