jax.extend.ffi.ffi_call

内容

jax.extend.ffi.ffi_call#

jax.extend.ffi.ffi_call(target_name, result_shape_dtypes, *args, vectorized=False, **kwargs)[source]#

调用外部函数接口 (FFI) 目标。

类似于 pure_callback()ffi_callvmap() 下的行为取决于 vectorized 的值。当 vectorizedTrue 时,假设 FFI 目标满足:ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])。换句话说,在带有额外前导维度的输入上调用 FFI 目标应该返回与在循环中调用它并沿着第零轴堆叠的结果相同。因此,FFI 目标将直接在批处理输入上调用(其中批处理轴是前导维度)。此外,回调应该返回具有对应的前导批处理轴的输出。如果 vectorizedFalse(默认行为),则将 ffi_callvmap() 下进行变换会导致使用 ffi_call 作为主体的一个 scan()

参数:
  • target_name (str) – 使用 register_custom_call_target() 注册的 XLA FFI 自定义调用目标的名称。

  • result_shape_dtypes (ResultMetadata | Sequence[ResultMetadata]) – 带有 shapedtype 属性的对象或对象序列,它们预计与自定义调用输出或输出的形状和数据类型匹配。 ShapeDtypeStruct 通常用于定义 result_shape_dtypes 的元素。 jax.core.abstract_token 可用于表示令牌类型输出。

  • *args (ArrayLike) – 传递给自定义调用的参数。

  • vectorized (bool) – 布尔值,指定回调函数是否能够以向量化方式操作,如上所述。

  • **kwargs (Any) – 作为命名属性传递给自定义调用的关键字参数,使用 XLA 的 FFI 接口。

返回值:

一个或多个 Array 对象,其形状和数据类型与 result_shape_dtypes 匹配。

返回类型:

Array | list[Array]