jax.default_device

jax.default_device#

jax.default_device = <jax._src.config.State object>#

用于 jax_default_device 配置选项的上下文管理器。

配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])将使用该 Device 作为 JAX 操作和 jit 函数调用的默认设备(对多设备计算没有影响,例如 pmapped 函数调用)。设置为 None 将使用系统默认设备。有关设备放置的更多信息,请参阅 控制数据和计算在设备上的放置

参数:

new_val (Any)