jax.lib 模块

jax.lib 模块#

The jax.lib 包是一组用于在 JAX 的 Python 前端和 XLA 后端之间桥接的内部工具和类型。

jax.lib.xla_bridge#

default_backend()

返回默认 XLA 后端的平台名称。

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

返回要使用的编译选项,这些选项来自标志值。

jax.lib.xla_client#

register_custom_call_target(name, fn[, ...])

注册一个自定义调用目标。