jax.default_device#

jax.default_device = <jax._src.config.State object>#

jax_default_device 配置选项的上下文管理器。

配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])以将该 Device 用作 JAX 操作和 jit'd 函数调用的默认设备(对多设备计算没有影响,例如 pmapped 函数调用)。设置为 None 以使用系统默认设备。有关设备放置的更多信息,请参见 控制设备上的数据和计算放置

参数:

new_val (Any)