jax.extend.ffi.ffi_call#
- jax.extend.ffi.ffi_call(target_name, result_shape_dtypes, *args, vectorized=False, **kwargs)[source]#
调用外部函数接口 (FFI) 目标。
类似于
pure_callback()
,ffi_call
在vmap()
下的行为取决于vectorized
的值。当vectorized
为True
时,假设 FFI 目标满足:ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
。换句话说,在带有额外前导维度的输入上调用 FFI 目标应该返回与在循环中调用它并沿着第零轴堆叠的结果相同。因此,FFI 目标将直接在批处理输入上调用(其中批处理轴是前导维度)。此外,回调应该返回具有对应的前导批处理轴的输出。如果vectorized
为False
(默认行为),则将ffi_call
在vmap()
下进行变换会导致使用ffi_call
作为主体的一个scan()
。- 参数:
target_name (str) – 使用
register_custom_call_target()
注册的 XLA FFI 自定义调用目标的名称。result_shape_dtypes (ResultMetadata | Sequence[ResultMetadata]) – 带有
shape
和dtype
属性的对象或对象序列,它们预计与自定义调用输出或输出的形状和数据类型匹配。ShapeDtypeStruct
通常用于定义result_shape_dtypes
的元素。jax.core.abstract_token
可用于表示令牌类型输出。*args (ArrayLike) – 传递给自定义调用的参数。
vectorized (bool) – 布尔值,指定回调函数是否能够以向量化方式操作,如上所述。
**kwargs (Any) – 作为命名属性传递给自定义调用的关键字参数,使用 XLA 的 FFI 接口。
- 返回值:
一个或多个
Array
对象,其形状和数据类型与result_shape_dtypes
匹配。- 返回类型: