jax.extend.ffi.register_ffi_target

jax.extend.ffi.register_ffi_target#

jax.extend.ffi.register_ffi_target(name, fn, platform='cpu', api_version=1, **kwargs)[source]#

注册一个外部函数目标。

参数:
  • name (str) – 目标的名称。

  • fn (Any) – 包含函数指针的 PyCapsule 对象,或一个 dict,其中键是 FFI 阶段名称(例如 “execute”),值是包含指向该阶段处理程序的指针的 PyCapsule 对象。

  • platform (str) – 目标平台。

  • api_version (int) – 要使用的 XLA 自定义调用 API 版本。支持的版本包括:1(默认)用于类型化 FFI 或 0 用于较早的“自定义调用”API。

  • kwargs (Any) – 任何额外的关键字参数都将直接传递给 register_custom_call_target() 以用于更高级的用例。

返回类型:

None