jax.lib 模块#

jax.lib 包是一组内部工具和类型,用于桥接 JAX 的 Python 前端和 XLA 后端。

jax.lib.xla_bridge#

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

返回要使用的编译选项,该选项来自标志值。

jax.lib.xla_client#

register_custom_call_target(name, fn[, ...])

注册自定义调用目标。