jax.extend.ffi.pycapsule

目录

jax.extend.ffi.pycapsule#

jax.extend.ffi.pycapsule(funcptr)[source]#

将 ctypes 函数指针包装在 PyCapsule 中。

此函数的主要用途,以及它为何存在于 jax.extend.ffi 子模块中的原因,是包装来自外部编译库的函数调用,以将其注册为 XLA 自定义调用。

示例用法

import ctypes
import jax
from jax.lib import xla_client

libfoo = ctypes.cdll.LoadLibrary('./foo.so')
xla_client.register_custom_call_target(
    name="bar",
    fn=jax.extend.ffi.pycapsule(libfoo.bar),
    platform=PLATFORM,
    api_version=API_VERSION
)
参数:

funcptr – 使用 ctypes 从动态库加载的函数指针。

返回值:

一个不透明的 PyCapsule 对象,包装了 funcptr