外部函数接口 (FFI)#

本教程需要 JAX v0.4.31 或更高版本。

虽然使用 JAX 内置的 jax.numpyjax.lax 接口可以轻松高效地实现各种数值运算,但有时通过“外部函数接口”(FFI)显式调用外部编译库会很有用。当特定的操作先前已在优化的 C 或 CUDA 库中实现,并且直接使用 JAX 重新实现这些计算并非易事时,这尤其有用,但它也可用于优化 JAX 程序的运行时或内存性能。话虽如此,FFI 通常应被视为最后手段,因为位于后端的 XLA 编译器或提供更低级别控制的 Pallas 内核语言通常会以较低的开发和维护成本生成高性能代码。

在考虑使用 FFI 时,应注意的一点是,JAX 不会自动知道如何通过外部函数进行微分。这意味着,如果您想将 JAX 的自动微分功能与外部函数一起使用,您还需要提供相关微分规则的实现。我们将在下面讨论一些可能的方法,但重要的是从一开始就指出此限制!

JAX 的 FFI 支持分两部分提供

  1. 一个来自 XLA 的仅头文件 C++ 库,它作为 JAX v0.4.29 的一部分打包,或者可以从 openxla/xla 项目获得,以及

  2. 一个 Python 前端,可在 jax.ffi 子模块中使用。

在本教程中,我们将通过一个简单的示例演示这两个组件的使用,然后讨论针对更复杂用例的一些更低级别的扩展。我们首先介绍 CPU 上的 FFI,然后在下面讨论针对 GPU 或多设备环境的泛化。

此示例和其他一些更高级用例的端到端代码可以在 GitHub 上的 JAX FFI 示例项目中找到,地址为 examples/ffi 在 JAX 存储库中

由于我们将在本教程的末尾演示如何分片 FFI 调用,因此让我们首先设置我们的环境,让 JAX 将其视为具有多个 CPU

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

一个简单的例子#

为了演示 FFI 接口的使用,我们将实现一个简单的“均方根 (RMS)”归一化函数。RMS 归一化接受一个形状为 \( (N,) \) 的数组 \(x\),并返回

\[ y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} \]

其中 \(\epsilon\) 是用于数值稳定性的调整参数。

这是一个有点傻的例子,因为它可以使用 JAX 轻松实现,如下所示

import jax
import jax.numpy as jnp


def rms_norm_ref(x, eps=1e-5):
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return x / scale

但是,它刚好足够非平凡,可以用来演示 FFI 的一些关键细节,同时仍然易于理解。我们将使用此参考实现来测试我们下面的 FFI 版本。

后端代码#

首先,我们需要一个使用 FFI 公开的 C++ RMS 归一化实现。这并非旨在具有特别高的性能,但您可以想象,如果您在 C++ 库中对 RMS 归一化有一些新的更好的实现,它可能具有如下的接口。因此,这是 C++ 中 RMS 归一化的一个简单实现

#include <cmath>
#include <cstdint>

float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sm = 0.0f;
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;
}

并且,对于我们的示例,这是我们想通过 FFI 公开给 JAX 的函数。

C++ 接口#

为了将我们的库函数公开给 JAX 和 XLA,我们需要使用 xla/ffi/api 目录中的仅头文件库提供的 API 编写一个薄包装器。有关此接口的更多信息,请查看 XLA 自定义调用文档。完整的源代码列表可以 在这里 下载,但关键的实现细节在这里重现

#include <functional>
#include <numeric>
#include <utility>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
  auto dims = buffer.dimensions();
  if (dims.size() == 0) {
    return std::make_pair(0, 0);
  }
  return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y) {
  auto [totalSize, lastDim] = GetDims(x);
  if (lastDim == 0) {
    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
  }
  for (int64_t n = 0; n < totalSize; n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
  }
  return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::F32>>()  // x
        .Ret<ffi::Buffer<ffi::F32>>()  // y
);

从底部开始,我们使用 XLA 提供的宏 XLA_FFI_DEFINE_HANDLER_SYMBOL 生成一些样板代码,这些代码将扩展为具有适当签名的名为 RmsNorm 的函数。但是,这里重要的是对 ffi::Ffi::Bind() 的调用,我们在其中定义输入和输出类型,以及任何参数的类型。

然后,在 RmsNormImpl 中,我们接受 ffi::Buffer 参数,其中包括有关缓冲区形状的信息以及指向基础数据的指针。在此实现中,我们将缓冲区的所有前导维度视为批次维度,并在最后一个轴上执行 RMS 归一化。GetDims 是一个辅助函数,为这种批处理行为提供支持。我们在 下面 更详细地讨论这种批处理行为,但总体思路是,它可以透明地处理输入参数最左侧维度中的批处理。在本例中,我们将除最后一个轴之外的所有轴都视为批次维度,但其他外部函数可能需要不同数量的非批次维度。

构建和注册 FFI 处理程序#

现在我们已经实现了最小的 FFI 包装器,我们需要将此函数 (RmsNorm) 公开给 Python。在本教程中,我们将 RmsNorm 编译到共享库中,并使用 ctypes 加载它,但另一种常见模式是使用 nanobindpybind11,如下所述。

为了编译共享库,我们在这里使用 CMake,但是您应该能够使用您喜欢的构建系统,而不会遇到太多麻烦。

!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
隐藏代码单元输出
-- The CXX compiler identification is GNU 11.4.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/bin/python3.10 (found suitable version "3.10.15", minimum required is "3.8") found components: Interpreter Development.Module
<string>:1: DeprecationWarning: jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.
-- XLA include directory: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include
-- Configuring done (1.3s)
-- Generating done (0.0s)
-- Build files have been written to: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/_build
[ 50%] Building CXX object CMakeFiles/rms_norm.dir/rms_norm.cc.o
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:654:68: warning: always_inline’ function might not be inlinable []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wattributes-Wattributes]8;;]
  654 | _ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
      |                                                             ^~~~~~~~~~~~
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:48,
                 from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_ExecutionStage)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:180:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  180 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_AttrType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:166:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  166 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_DataType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:153:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  153 | }
      | ^
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_ArgType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:722:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  722 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_RetType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:797:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  797 | }
      | ^
[100%] Linking CXX shared library librms_norm.so
[100%] Built target rms_norm
-- Install configuration: "Release"
-- Installing: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/librms_norm.so

有了这个编译好的库,我们现在需要通过 register_ffi_target() 函数向 XLA 注册此处理程序。此函数希望我们的处理程序(指向 C++ 函数 RmsNorm 的函数指针)包装在 PyCapsule 中。JAX 提供了一个辅助函数 pycapsule() 来帮助完成此操作

import ctypes
from pathlib import Path

path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jax.ffi.register_ffi_target(
    "rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")

提示

如果您熟悉传统的“自定义调用”API,那么值得注意的是,您还可以通过手动指定关键字参数 api_version=0 来使用 register_ffi_target() 注册自定义调用目标。 register_ffi_target() 的默认 api_version1,这是我们在此处使用的新“类型化” FFI API。

一种替代方法:将处理程序公开给 Python 的一种常见替代模式是使用 nanobindpybind11 来定义一个可以导入的小型 Python 扩展。对于我们这里的示例,nanobind 代码将是

#include <type_traits>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"

namespace nb = nanobind;

template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
  // This check is optional, but it can be helpful for avoiding invalid handlers.
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
                "Encapsulated function must be and XLA FFI handler");
  return nb::capsule(reinterpret_cast<void *>(fn));
}

NB_MODULE(rms_norm, m) {
  m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}

然后,在 Python 中,我们可以使用以下代码注册此处理程序

# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib

jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")

前端代码#

现在我们已经注册了我们的 FFI 处理程序,使用 ffi_call() 函数从 JAX 调用我们的 C++ 库非常简单

import numpy as np


def rms_norm(x, eps=1e-5):
  # We only implemented the `float32` version of this function, so we start by
  # checking the dtype. This check isn't strictly necessary because type
  # checking is also performed by the FFI when decoding input and output
  # buffers, but it can be useful to check types in Python to raise more
  # informative errors.
  if x.dtype != jnp.float32:
    raise ValueError("Only the float32 dtype is implemented by rms_norm")

  call = jax.ffi.ffi_call(
    # The target name must be the same string as we used to register the target
    # above in `register_custom_call_target`
    "rms_norm",

    # In this case, the output of our FFI function is just a single array with
    # the same shape and dtype as the input. We discuss a case with a more
    # interesting output type below.
    jax.ShapeDtypeStruct(x.shape, x.dtype),

    # The `vmap_method` parameter controls this function's behavior under `vmap`
    # as discussed below.
    vmap_method="broadcast_all",
  )

  # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
  # the attribute `eps`. Our FFI function expects this to have the C++ `float`
  # type (which corresponds to numpy's `float32` type), and it must be a
  # static parameter (i.e. not a JAX array).
  return call(x, eps=np.float32(eps))


# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)

此代码单元包含许多内联注释,这些注释应解释此处发生的大部分事情,但有几点值得明确强调。这里的大部分繁重工作都是由 ffi_call() 函数完成的,它告诉 JAX 如何为特定输入集调用外部函数。请务必注意,ffi_call() 的第一个参数必须是一个字符串,该字符串与我们在上面调用 register_custom_call_target 时使用的目标名称匹配。

任何属性(使用上面的 C++ 包装器中的 Attr 定义)都应作为关键字参数传递给 ffi_call()。请注意,我们显式地将 eps 强制转换为 np.float32,因为我们的 FFI 库需要 C float,并且我们不能在这里使用 jax.numpy,因为这些参数必须是静态参数。

ffi_call()vmap_method 参数定义了此 FFI 调用如何与 vmap() 交互,如下所述。

提示

如果您熟悉早期的“自定义调用”接口,您可能会惊讶于我们没有将问题维度(批次大小等)作为参数传递给 ffi_call()。在早期的 API 中,后端没有接收关于输入数组的元数据的机制,但由于 FFI 在 Buffer 对象中包含了维度信息,我们不再需要在降低过程中使用 Python 来计算它。这项更改的一个主要好处是 ffi_call() 可以开箱即用地支持一些简单的 vmap() 语义,如下所述。

使用 vmap 进行批处理#

ffi_call() 使用 vmap_method 参数开箱即用地支持一些简单的 vmap() 语义。pure_callback() 的文档提供了关于 vmap_method 参数的更多详细信息,相同的行为也适用于 ffi_call()

最简单的 vmap_method"sequential"。在这种情况下,当使用 vmap 时,ffi_call 将被重写为以 ffi_call 为主体的 scan()。此实现是通用的,但并行化效果不佳。许多 FFI 调用提供了更有效的批处理行为,在某些简单情况下,可以使用 "expand_dims""broadcast_all" 方法来公开更好的实现。

在这种情况下,由于我们只有一个输入参数,"expand_dims""broadcast_all" 实际上具有相同的行为。使用这些方法的具体假设是,外部函数知道如何处理批处理维度。另一种说法是,假设在批处理输入上调用 ffi_call 的结果等于将 ffi_call 重复应用于批处理输入中的每个元素,大致如下:

ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])

提示

请注意,当我们有多个输入参数时,情况会变得稍微复杂。为简单起见,我们将在本教程中使用 "broadcast_all",这保证了所有输入都将被广播以具有相同的批处理维度,但也可以实现一个外部函数来处理 "expand_dims" 方法。pure_callback() 的文档包含了一些这方面的示例。

我们的 rms_norm 实现具有适当的语义,并且开箱即用地支持使用 vmap_method="broadcast_all"vmap

np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)

我们可以检查 rms_normvmap()jaxpr,以确认它没有使用 scan() 重写。

jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[8,4]. let
    b:f32[8,4] = ffi_call[
      attributes=(('eps', np.float32(1e-05)),)
      custom_call_api_version=4
      has_side_effect=False
      input_layouts=((1, 0),)
      input_output_aliases=()
      legacy_backend_config=None
      output_layouts=((1, 0),)
      result_avals=(ShapedArray(float32[8,4]),)
      target_name=rms_norm
      vectorized=Deprecated
      vmap_method=broadcast_all
    ] a
  in (b,) }

使用 vmap_method="sequential",对 ffi_call 进行 vmap 处理将回退到以 ffi_call 为主体的 jax.lax.scan()

def rms_norm_sequential(x, eps=1e-5):
  return jax.ffi.ffi_call(
    "rms_norm",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="sequential",
  )(x, eps=np.float32(eps))


jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[8,4]. let
    b:f32[8,4] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[4]. let
          d:f32[4] = ffi_call[
            attributes=(('eps', np.float32(1e-05)),)
            custom_call_api_version=4
            has_side_effect=False
            input_layouts=((0,),)
            input_output_aliases=()
            legacy_backend_config=None
            output_layouts=((0,),)
            result_avals=(ShapedArray(float32[4]),)
            target_name=rms_norm
            vectorized=Deprecated
            vmap_method=sequential
          ] c
        in (d,) }
      length=8
      linear=(False,)
      num_carry=0
      num_consts=0
      reverse=False
      unroll=1
    ] a
  in (b,) }

如果您的外部函数提供了此简单 vmap_method 参数不支持的有效批处理规则,则也可以使用实验性的 custom_vmap 接口定义更灵活的自定义 vmap 规则,但也值得在 JAX 问题跟踪器上提交一个问题来描述您的用例。

微分#

与批处理不同,ffi_call() 不提供对外部函数的自动微分 (AD) 的任何默认支持。就 JAX 而言,外部函数是一个黑盒子,无法对其进行检查以确定微分时的适当行为。因此,ffi_call() 用户有责任定义自定义导数规则。

有关自定义导数规则的更多详细信息,请参见 自定义导数教程,但用于为外部函数实现微分的最常见模式是定义一个 custom_vjp(),它本身会调用外部函数。在这种情况下,我们实际上定义了两个新的 FFI 调用:

  1. rms_norm_fwd 返回两个输出:(a) “原始”结果,以及 (b) 用于后向传递的“残差”。

  2. rms_norm_bwd 接收残差和输出的余切,并返回输入的余切。

我们不会深入探讨 RMS 归一化后向传递的细节,但请查看 C++ 源代码,了解这些函数如何在后端实现。这里要强调的主要一点是,计算出的“残差”的形状与原始输出不同,因此,在对 res_norm_fwdffi_call() 中,输出类型有两个形状不同的元素。

此自定义导数规则可以如下方式连接起来:

jax.ffi.register_ffi_target(
  "rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jax.ffi.register_ffi_target(
  "rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)


def rms_norm_fwd(x, eps=1e-5):
  y, res = jax.ffi.ffi_call(
    "rms_norm_fwd",
    (
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
    ),
    vmap_method="broadcast_all",
  )(x, eps=np.float32(eps))
  return y, (res, x)


def rms_norm_bwd(eps, res, ct):
  del eps
  res, x = res
  assert res.shape == ct.shape[:-1]
  assert x.shape == ct.shape
  return (
    jax.ffi.ffi_call(
      "rms_norm_bwd",
      jax.ShapeDtypeStruct(ct.shape, ct.dtype),
      vmap_method="broadcast_all",
    )(res, x, ct),
  )


rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)

# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
  jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)

此时,我们可以为许多 JAX 应用程序透明地使用我们新的 rms_norm 函数,并且它会在诸如 vmap()grad() 之类的标准 JAX 函数转换下进行适当的转换。此示例不支持的一件事是前向模式 AD(例如,jax.jvp()),因为 custom_vjp() 仅限于反向模式。JAX 目前没有公开用于同时自定义前向模式和反向模式 AD 的公共 API,但此类 API 已在路线图上,因此如果您在实践中遇到此限制,请 提交一个问题来描述您的用例。

此示例不支持的另一个 JAX 功能是高阶 AD。可以通过将上述 res_norm_bwd 函数包装在 jax.custom_jvp()jax.custom_vjp() 装饰器中来解决此问题,但我们不会在此处详细介绍该高级用例。

GPU 上的 FFI 调用#

到目前为止,我们仅与在 CPU 上运行的外部函数进行交互,但 JAX 的 FFI 也支持调用 GPU 代码。由于此文档页面是在无法访问 GPU 的计算机上自动生成的,因此我们无法在此处执行任何特定于 GPU 的示例,但我们将介绍要点。

在为 CPU 定义 FFI 包装器时,我们使用的函数签名是:

ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y)

要更新此签名以与 CUDA 内核接口,此签名将变为:

ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
                       ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y)

并且处理程序定义已更新,在其绑定中包含一个 Ctx

XLA_FFI_DEFINE_HANDLER(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::F32>>()  // x
        .Ret<ffi::Buffer<ffi::F32>>()  // y
);

然后,RmsNormImpl 可以使用 CUDA 流来启动 CUDA 内核。

在前端,注册代码将被更新以指定适当的平台

jax.ffi.register_ffi_target(
  "rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)

支持多个平台#

为了支持在 GPU 和 CPU 上运行我们的 rms_norm 函数,我们可以将上面的实现与 jax.lax.platform_dependent() 函数结合使用

def rms_norm_cross_platform(x, eps=1e-5):
  assert x.dtype == jnp.float32
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

  def impl(target_name):
    return lambda x: jax.ffi.ffi_call(
      target_name,
      out_type,
      vmap_method="broadcast_all",
    )(x, eps=np.float32(eps))

  return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))


np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)

此版本的函数将根据运行时平台调用适当的 FFI 目标。

顺便提一下,值得注意的是,虽然 jaxpr 和降低后的 HLO 都包含对两个 FFI 目标的引用

jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[8,4]. let
    b:i32[] = platform_index[has_default=False platforms=(('cpu',), ('cuda',))] 
    c:i32[] = clamp 0 b 1
    d:f32[8,4] = cond[
      branches=(
        { lambda ; e:f32[8,4]. let
            f:f32[8,4] = ffi_call[
              attributes=(('eps', np.float32(1e-05)),)
              custom_call_api_version=4
              has_side_effect=False
              input_layouts=((1, 0),)
              input_output_aliases=()
              legacy_backend_config=None
              output_layouts=((1, 0),)
              result_avals=(ShapedArray(float32[8,4]),)
              target_name=rms_norm
              vectorized=Deprecated
              vmap_method=broadcast_all
            ] e
          in (f,) }
        { lambda ; g:f32[8,4]. let
            h:f32[8,4] = ffi_call[
              attributes=(('eps', np.float32(1e-05)),)
              custom_call_api_version=4
              has_side_effect=False
              input_layouts=((1, 0),)
              input_output_aliases=()
              legacy_backend_config=None
              output_layouts=((1, 0),)
              result_avals=(ShapedArray(float32[8,4]),)
              target_name=rms_norm_cuda
              vectorized=Deprecated
              vmap_method=broadcast_all
            ] g
          in (h,) }
      )
    ] c a
  in (d,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.clamp %c_0, %c, %c_1 : tensor<i32>
    %1 = "stablehlo.case"(%0) ({
      %2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
      stablehlo.return %2 : tensor<8x4xf32>
    }, {
      %2 = stablehlo.custom_call @rms_norm_cuda(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
      stablehlo.return %2 : tensor<8x4xf32>
    }) : (tensor<i32>) -> tensor<8x4xf32>
    return %1 : tensor<8x4xf32>
  }
}

但当函数被编译时,适当的 FFI 已被选中

print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[8,4]{1,0})->f32[8,4]{1,0}}

ENTRY main.3 {
  Arg_0.1 = f32[8,4]{1,0} parameter(0)
  ROOT custom-call.2 = f32[8,4]{1,0} custom-call(Arg_0.1), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI
}

并且使用 jax.lax.platform_dependent() 不会产生运行时开销,并且编译后的程序不会包含对不可用 FFI 目标的任何引用。

分片#

JAX 的大多数大型用户都使用其 API 在多个设备上进行分布式计算。正如在 并行计算简介 中讨论的那样,JAX 中的并行性通过跨设备分片数据来控制,并且大多数 JAX 操作可以在任何受支持的并行编程范例(从自动到完全手动)中使用。但是,对于 FFI 调用来说,情况稍微复杂一些。由于 FFI 调用的内部机制对于 JAX 和 XLA 都是不透明的,因此当数据被分片时,FFI 调用通常不会显示最佳(甚至良好的)性能。

在深入了解 FFI 细节之前,让我们考虑一下我们的纯 JAX 参考 RMS 归一化实现(本文档顶部定义的 rms_norm_ref 函数)在分片输入下的行为。如上所述,我们的实现将输入的所有前导轴视为批次维度,并且归一化沿最后一个轴执行。这意味着,如果数据沿任何批次维度分片,但在最后一个维度上复制,则不需要通信。这可以通过沿其第一个维度分片我们上面的二维测试数据,并检查已编译的 HLO 中是否存在诸如 all-gatherall-reduce 等操作来查看。

from jax.sharding import PartitionSpec as P

assert len(jax.devices()) == 4  # Set using the XLA_FLAGS environment variable
mesh = jax.make_mesh((4,), ("x",))

batch_shd = jax.NamedSharding(mesh, P("x", None))
x_batch_shd = jax.device_put(x, batch_shd)
hlo_batch = jax.jit(rms_norm_ref, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text()
assert "all-" not in hlo_batch

但是,如果数据沿最后一个轴分片,则需要通信(在这种情况下为 all-reduce)来计算归一化中的总和

data_shd = jax.NamedSharding(mesh, P(None, "x"))
x_data_shd = jax.device_put(x, data_shd)
hlo_data = jax.jit(rms_norm_ref, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
assert "all-reduce" in hlo_data

现在,如果我们尝试天真地使用我们 FFI 版本的相同模型,它可以正常运行并获得正确的答案

output = jax.jit(rms_norm, out_shardings=batch_shd)(x_batch_shd)
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)

但是,如果你查看编译后的 HLO(为清楚起见,省略了辅助函数),你会看到

  1. 首先通过 all-gather 操作将数据完全复制到每个设备上,

  2. 然后在每个设备上的完整数据集上执行 FFI 调用,以及

  3. 输出被切片以丢弃未使用的部分。

hlo = jax.jit(rms_norm, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip()
print(hlo.split("\n\n")[-1])
ENTRY %main.5_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="x"}
  %all-gather = f32[8,4]{1,0} all-gather(f32[2,4]{1,0} %param), channel_id=1, replica_groups=[1,4]<=[4], dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(rms_norm)/jit(main)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}
  %custom-call.0 = f32[8,4]{1,0} custom-call(f32[8,4]{1,0} %all-gather), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(rms_norm)/jit(main)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}, backend_config={eps = 9.99999974E-6 : f32}
  %partition-id = u32[] partition-id()
  ROOT %multiply_dynamic-slice_fusion = f32[2,4]{1,0} fusion(f32[8,4]{1,0} %custom-call.0, u32[] %partition-id), kind=kLoop, calls=%fused_computation
}

显然(对我们来说!)这不是此函数的最佳分区,但这是 JAX/XLA 在给定信息的情况下所能做的最好的。

为了生成更好的分区逻辑,我们可以使用 shard_map()custom_partitioning(),我们在此处讨论这两个选项。话虽如此,为所有输入生成最佳分区并非易事,因为有时这需要算法更改。具体来说,让我们添加对“批次分区”的支持,该支持处理数据在批次维度上分片的情况,但在最后一个维度上分片总是需要重新分片。

使用 shard_map#

如果您通过 shard_map() 使用手动分片控制,则程序中的任何 FFI 调用都应已适当分区

from functools import partial
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
def rms_norm_shmap(x):
  return rms_norm(x)

np.testing.assert_allclose(rms_norm_shmap(x_batch_shd), rms_norm_ref(x), rtol=1e-5)
print(jax.jit(rms_norm_shmap, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
HloModule jit_rms_norm_shmap, is_scheduled=true, entry_computation_layout={(f32[2,4]{1,0})->f32[2,4]{1,0}}, num_partitions=4

ENTRY %main.12_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="x"}
  ROOT %custom-call.1 = f32[2,4]{1,0} custom-call(f32[2,4]{1,0} %param), custom_call_target="rms_norm", operand_layout_constraints={f32[2,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(rms_norm_shmap)/jit(main)/jit(shmap_body)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}, backend_config={eps = 9.99999974E-6 : f32}
}

正如您在此程序中看到的那样,如果输入和输出分片与 shard_map 规范匹配,则不需要通信,并且 FFI 调用在数据的适当分片子集上执行。

您还可以使用与 shard_map 规范不匹配的分片输入和输出,但是(与 FFI 无关)这将需要重新分片,正如编译后的 HLO 中的 all-to-all 操作所见

hlo_data_shmap = jax.jit(rms_norm_shmap, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
assert "all-to-all" in hlo_data_shmap

使用 custom partitioning#

如果您不能使用 shard_map(),另一种方法是使用 custom_partitioning(),它支持通过 jax.jit() 进行自动并行化。custom_partitioning() 通过在 XLA 编译器的分区传递中添加 Python 回调来工作,这允许非常灵活的逻辑,但也带来一些粗糙的边缘。我们不会在此处详细介绍这些注意事项,但您应该注意的主要问题是

  1. 当与 JAX 的 持久编译缓存 一起使用时,custom_partitioning 可能会导致意外的缓存未命中。可以使用 jax_remove_custom_partitioning_ptr_from_cache_key 配置标志来缓解这种情况,但这也不总是合适的。

  2. 调试 custom_partitioning 逻辑可能很乏味,因为 Python 错误并不总是传播,而是导致你的 Python 进程退出。话虽如此,任何异常都将显示在进程日志中,因此您应该能够在那里找到它们。

综上所述,以下是如何使用 custom_partitioning() 包装我们的 FFI 实现的 rms_norm

from jax.experimental.custom_partitioning import custom_partitioning

@partial(custom_partitioning, static_argnums=(1,))
def rms_norm_partitioned(x, eps=1e-5):
  return rms_norm(x, eps=eps)

def replicate_sharding_on_last_dim(mesh, sharding, target_info):
  # Our implementation supports trivial sharding on any batch dimensions, but the data
  # must be replicated on the last (non-batch) dimension.
  rank = len(target_info.shape)
  num_batch_dims = min(len(sharding.spec), rank - 1)

  # The Nones here indicate which dimensions should be replicated.
  names = tuple(sharding.spec[:num_batch_dims]) + (None,) * (rank - num_batch_dims)
  return jax.NamedSharding(mesh, P(*names))

def rms_norm_infer_sharding_from_operands(eps, mesh, args_info, result_info):
  del eps  # unused
  arg_info, = args_info
  result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)

  # In this case, we only have a single output, but the return value from this function
  # must have the same pytree structure as the output from the underlying function
  # (`rms_norm` in this case).
  return result_sharding

def rms_norm_partition(eps, mesh, args_info, result_info):
  arg_info, = args_info
  arg_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, arg_info)
  result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)

  # This is the function that computes the partitioned model on the appropriate subset
  # of the data.
  def partitioned_rms_norm(x):
    return rms_norm(x, eps=eps)

  # Note that the third element of our returned tuple must be the shardings for the
  # _outputs_ and its pytree structure must match the output of `rms_norm`. Similarly,
  # the fourth element must have the same pytree structure as the _inputs_ to
  # `rms_norm`. In this case, there is only one input, but it must be returned within
  # a `tuple` anyways.
  return mesh, partitioned_rms_norm, result_sharding, (arg_sharding,)

rms_norm_partitioned.def_partition(
    infer_sharding_from_operands=rms_norm_infer_sharding_from_operands,
    partition=rms_norm_partition,
)

output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd)
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)
print(jax.jit(rms_norm_partitioned, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
HloModule jit__unnamed_wrapped_function_, is_scheduled=true, entry_computation_layout={(f32[2,4]{1,0})->f32[2,4]{1,0}}, num_partitions=4

ENTRY %main.5_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="args[0]"}
  ROOT %custom-call.0 = f32[2,4]{1,0} custom-call(f32[2,4]{1,0} %param), custom_call_target="rms_norm", operand_layout_constraints={f32[2,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/custom_partitioning" source_file="/tmp/ipykernel_924/1708142274.py" source_line=49}, backend_config={eps = 9.99999974E-6 : f32}
}

从上面的编译程序中可以看到,当输入在批次维度上分片时,此 custom_partitioning 逻辑产生的程序与上面的 shard_map 版本完全相同。

但是,值得注意的是,当输入沿数据维度分片时,行为是不同的。当在 shard_map 下使用时,数据在批次维度上重新分片,而使用 custom_partitioning 时,数据被收集到每个设备上。

hlo_data_partitioned = jax.jit(rms_norm_partitioned, out_shardings=data_shd).lower(x_data_shd).compile().as_text().strip()
assert "all-gather" in hlo_data_partitioned

为了也支持后向传递的自动并行化,我们还需要为 rms_norm_fwdrms_norm_bwd 编写(类似的)custom_partitioning() 规则,但我们将这些留给读者作为练习。

高级主题#

本教程涵盖了使用 JAX 的 FFI 启动和运行所需的大部分基本步骤,但高级用例可能需要更多功能。我们将把这些主题留给未来的教程,但这里有一些可能有用的参考资料

  • 支持多种数据类型:在本教程的示例中,我们限制为仅支持 float32 输入和输出,但许多用例需要支持多种不同的输入类型。处理此问题的一种方法是为所有支持的输入类型注册不同的 FFI 目标,然后使用 Python 根据输入类型为 jax.ffi.ffi_call() 选择适当的目标。但是,根据支持案例的组合情况,这种方法可能会很快变得难以控制。因此,也可以定义 C++ 处理程序来接受 ffi::AnyBuffer 而不是 ffi::Buffer<Dtype>。然后,输入缓冲区将包含一个 element_type() 方法,该方法可用于定义后端中适当的数据类型调度逻辑。

  • 有状态的外部函数:也可以使用 FFI 来包装具有关联状态的函数。 XLA 测试套件中包含一个低级示例,并且未来的教程将包含更多详细信息。