外部函数接口 (FFI)#
本教程需要 JAX v0.4.31 或更高版本。
尽管可以使用 JAX 内置的 jax.numpy
和 jax.lax
接口轻松高效地实现各种数值运算,但有时通过“外部函数接口”(FFI)显式调用外部编译库会很有用。当特定操作已在优化的 C 或 CUDA 库中实现,并且直接使用 JAX 重新实现这些计算并非易事时,这尤其有用。此外,它也可以用于优化 JAX 程序的运行时或内存性能。话虽如此,通常应将 FFI 视为最后的选择,因为位于后端的 XLA 编译器或提供更低级别控制的 Pallas 内核语言通常会以更低的开发和维护成本生成性能良好的代码。
在考虑使用 FFI 时应考虑的一点是,JAX 不会自动知道如何通过外部函数进行微分。这意味着,如果您想将 JAX 的自动微分功能与外部函数一起使用,则还需要提供相关微分规则的实现。我们将在下面讨论一些可能的方法,但从一开始就指出这个限制非常重要!
JAX 的 FFI 支持分为两部分
XLA 的一个仅包含头文件的 C++ 库,作为 JAX v0.4.29 的一部分打包,或可从 openxla/xla 项目获得,以及
一个 Python 前端,在
jax.extend.ffi
子模块中可用。
在本教程中,我们将通过一个简单的示例演示这两个组件的使用,然后讨论一些用于更复杂用例的低级扩展。我们首先介绍 CPU 上的 FFI,并在下面讨论对 GPU 或多设备环境的推广。
此示例和一些更高级用例的端到端代码可以在 GitHub 上的 JAX FFI 示例项目中找到,网址为 JAX 存储库中的 examples/ffi
。
一个简单的例子#
为了演示 FFI 接口的使用,我们将实现一个简单的“均方根 (RMS)”归一化函数。RMS 归一化取一个形状为 \( (N,) \) 的数组 \(x\),并返回
其中 \(\epsilon\) 是一个用于数值稳定性的调整参数。
这是一个有点愚蠢的例子,因为它可以使用 JAX 轻松实现,如下所示
import jax
import jax.numpy as jnp
def rms_norm_ref(x, eps=1e-5):
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
return x / scale
但是,它对于演示 FFI 的一些关键细节来说已经足够不平凡,同时仍然易于理解。我们将使用此参考实现来测试我们下面的 FFI 版本。
后端代码#
首先,我们需要用 C++ 实现 RMS 归一化,我们将使用 FFI 公开它。这并不是要特别注重性能,但您可以想象,如果您在 C++ 库中实现了 RMS 归一化的某些新的更好的实现,则它可能具有如下所示的接口。因此,这里有一个用 C++ 实现 RMS 归一化的简单实现
#include <cmath>
#include <cstdint>
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}
并且,对于我们的示例,这是我们想要通过 FFI 公开给 JAX 的函数。
C++ 接口#
为了将我们的库函数公开给 JAX 和 XLA,我们需要使用 xla/ffi/api
目录中的仅包含头文件的库提供的 API 编写一个薄包装器。有关此接口的更多信息,请查看 XLA 自定义调用文档。完整的源代码列表可以在此处下载,但关键的实现细节在此处重现
#include <functional>
#include <numeric>
#include <utility>
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace ffi = xla::ffi;
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
从底部开始,我们使用 XLA 提供的宏 XLA_FFI_DEFINE_HANDLER_SYMBOL
来生成一些样板代码,这些样板代码将扩展为一个名为 RmsNorm
且具有适当签名的函数。但是,这里的重要内容都在对 ffi::Ffi::Bind()
的调用中,我们在其中定义了输入和输出类型以及任何参数的类型。
然后,在 RmsNormImpl
中,我们接受 ffi::Buffer
参数,这些参数包括有关缓冲区形状和指向底层数据的指针的信息。在此实现中,我们将缓冲区的所有前导维度视为批次维度,并在最后一个轴上执行 RMS 归一化。GetDims
是一个辅助函数,用于支持此批处理行为。我们将在下面更详细地讨论此批处理行为,但总体思路是,透明地处理输入参数最左侧维度中的批处理会很有用。在本例中,我们将除最后一个轴之外的所有轴都视为批次维度,但其他外部函数可能需要不同数量的非批次维度。
构建和注册 FFI 处理程序#
现在我们已经实现了最小的 FFI 包装器,我们需要将此函数 (RmsNorm
) 公开给 Python。在本教程中,我们将 RmsNorm
编译为共享库,并使用 ctypes 加载它,但另一种常见的模式是使用 nanobind 或 pybind11,如下所述。
要编译共享库,我们在此处使用 CMake,但您应该可以使用您喜欢的构建系统,而不会遇到太多麻烦。
!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
显示代码单元输出
-- The CXX compiler identification is GNU 11.4.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/bin/python3.10 (found suitable version "3.10.15", minimum required is "3.8") found components: Interpreter Development.Module
-- XLA include directory: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include
-- Configuring done (0.8s)
-- Generating done (0.0s)
-- Build files have been written to: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/_build
[ 50%] Building CXX object CMakeFiles/rms_norm.dir/rms_norm.cc.o
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:648:68: warning: ‘always_inline’ function might not be inlinable []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wattributes-Wattributes]8;;]
648 | _ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
| ^~~~~~~~~~~~
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:48,
from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_ExecutionStage)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:176:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
176 | }
| ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_AttrType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:162:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
162 | }
| ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_DataType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:149:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
149 | }
| ^
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_ArgType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:716:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
716 | }
| ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_RetType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:791:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
791 | }
| ^
[100%] Linking CXX shared library librms_norm.so
[100%] Built target rms_norm
-- Install configuration: "Release"
-- Installing: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/librms_norm.so
有了这个编译后的库,我们现在需要通过 register_ffi_target()
函数向 XLA 注册此处理程序。此函数期望我们的处理程序(指向 C++ 函数 RmsNorm
的函数指针)包装在 PyCapsule
中。JAX 提供了一个辅助函数 pycapsule()
来帮助完成此操作
import ctypes
from pathlib import Path
import jax.extend as jex
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jex.ffi.register_ffi_target(
"rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
提示
如果您熟悉传统的“自定义调用”API,值得注意的是,您还可以使用 register_ffi_target()
通过手动指定关键字参数 api_version=0
来注册自定义调用目标。 register_ffi_target()
的默认 api_version
是 1
,即我们在此处使用的新“类型化” FFI API。
一种替代方法:将处理程序公开给 Python 的一种常见替代模式是使用 nanobind 或 pybind11 来定义一个可以导入的小型 Python 扩展。对于我们这里的示例,nanobind 代码将是
#include <type_traits>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
namespace nb = nanobind;
template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
// This check is optional, but it can be helpful for avoiding invalid handlers.
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}
NB_MODULE(rms_norm, m) {
m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}
然后,在 Python 中,我们可以使用以下命令注册此处理程序
# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib
jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
前端代码#
现在我们已经注册了 FFI 处理程序,可以使用 ffi_call()
函数从 JAX 调用我们的 C++ 库,这非常简单
import numpy as np
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
call = jex.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_custom_call_target`
"rms_norm",
# In this case, the output of our FFI function is just a single array with
# the same shape and dtype as the input. We discuss a case with a more
# interesting output type below.
jax.ShapeDtypeStruct(x.shape, x.dtype),
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vmap_method="broadcast_all",
)
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return call(x, eps=np.float32(eps))
# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
此代码单元包含许多内联注释,这些注释应解释此处发生的大部分事情,但有几点值得明确强调。这里的大部分繁重工作是由 ffi_call()
函数完成的,该函数告诉 JAX 如何为特定输入集调用外部函数。需要注意的是,ffi_call()
的第一个参数必须是一个字符串,该字符串与我们在上面调用 register_custom_call_target
时使用的目标名称匹配。
任何属性(使用上面的 C++ 包装器中的 Attr
定义)都应作为关键字参数传递给 ffi_call()
。请注意,我们将 eps
显式强制转换为 np.float32
,因为我们的 FFI 库期望一个 C float
,并且我们不能在此处使用 jax.numpy
,因为这些参数必须是静态参数。
ffi_call()
的 vmap_method
参数定义了此 FFI 调用如何与 vmap()
交互,如下所述。
提示
如果您熟悉早期的“自定义调用”接口,您可能会惊讶于我们没有将问题维度(批次大小等)作为参数传递给 ffi_call()
。在此早期 API 中,后端没有接收有关输入数组的元数据的机制,但由于 FFI 将维度信息与 Buffer
对象一起包含,因此我们不再需要在降低时使用 Python 计算此信息。此更改的主要好处之一是 ffi_call()
可以开箱即用地支持一些简单的 vmap()
语义,如下所述。
使用 vmap
进行批处理#
ffi_call()
使用 vmap_method
参数,开箱即用地支持一些简单的 vmap()
语义。pure_callback()
的文档提供了关于 vmap_method
参数的更多细节,并且相同的行为适用于 ffi_call()
。
最简单的 vmap_method
是 "sequential"
。在这种情况下,当进行 vmap
处理时,ffi_call
将被重写为 scan()
,并且在 body 中调用 ffi_call
。这个实现是通用的,但它的并行化效果不太好。许多 FFI 调用提供了更有效的批处理行为,并且在一些简单的情况下,可以使用 "expand_dims"
或 "broadcast_all"
方法来公开更好的实现。
在这种情况下,由于我们只有一个输入参数,"expand_dims"
和 "broadcast_all"
实际上具有相同的行为。使用这些方法的特定假设是外部函数知道如何处理批处理维度。另一种说法是,假设在批处理输入上调用 ffi_call
的结果等于将 ffi_call
重复应用于批处理输入中的每个元素,大致如下:
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
提示
请注意,当我们有多个输入参数时,事情会变得有点复杂。为简单起见,我们将在本教程中使用 "broadcast_all"
,它保证所有输入将被广播到具有相同的批处理维度,但也可以实现一个外部函数来处理 "expand_dims"
方法。pure_callback()
的文档包含一些这方面的例子。
我们实现的 rms_norm
具有适当的语义,并且开箱即用地支持使用 vmap_method="broadcast_all"
的 vmap
。
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
我们可以检查 jaxpr 中 vmap()
的 rms_norm
,以确认它没有使用 scan()
进行重写。
jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[3,5]. let
b:f32[3,5] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[3,5]),)
target_name=rms_norm
vectorized=Deprecated
vmap_method=broadcast_all
] a
in (b,) }
使用 vmap_method="sequential"
,vmap
一个 ffi_call
将回退到 jax.lax.scan()
,并且 body 中调用 ffi_call
。
def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="sequential",
)(x, eps=np.float32(eps))
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[3,5]. let
b:f32[3,5] = scan[
_split_transpose=False
jaxpr={ lambda ; c:f32[5]. let
d:f32[5] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((0,),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((0,),)
result_avals=(ShapedArray(float32[5]),)
target_name=rms_norm
vectorized=Deprecated
vmap_method=sequential
] c
in (d,) }
length=3
linear=(False,)
num_carry=0
num_consts=0
reverse=False
unroll=1
] a
in (b,) }
如果您的外部函数提供了此简单的 vmap_method
参数不支持的高效批处理规则,则也可以使用实验性的 custom_vmap
接口定义更灵活的自定义 vmap
规则,但也值得在 JAX 问题跟踪器上打开一个问题来描述您的用例。
微分#
与批处理不同,ffi_call()
不为外部函数的自动微分 (AD) 提供任何默认支持。就 JAX 而言,外部函数是一个黑盒子,无法检查以确定微分时的适当行为。因此,定义自定义导数规则是 ffi_call()
用户的责任。
有关自定义导数规则的更多详细信息,请参阅自定义导数教程,但用于实现外部函数微分的最常见模式是定义一个 custom_vjp()
,它本身会调用一个外部函数。在这种情况下,我们实际定义了两个新的 FFI 调用:
rms_norm_fwd
返回两个输出:(a) “原始”结果,和 (b) 在反向传递中使用的“残差”。rms_norm_bwd
接受残差和输出余切,并返回输入余切。
我们不会深入探讨 RMS 归一化反向传递的细节,但请查看C++ 源代码,了解这些函数如何在后端实现。这里要强调的重点是,计算出的“残差”的形状与原始输出不同,因此,在 ffi_call()
中对 res_norm_fwd
的调用,输出类型有两个形状不同的元素。
此自定义导数规则可以如下连接:
jex.ffi.register_ffi_target(
"rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jex.ffi.register_ffi_target(
"rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return y, (res, x)
def rms_norm_bwd(eps, res, ct):
del eps
res, x = res
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",
)(res, x, ct),
)
rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)
此时,我们可以透明地将新的 rms_norm
函数用于许多 JAX 应用程序,它将在标准的 JAX 函数转换(如 vmap()
和 grad()
)下进行适当的转换。此示例不支持的一件事是前向模式 AD(例如,jax.jvp()
),因为 custom_vjp()
仅限于反向模式。JAX 目前没有公开的 API 来同时自定义前向模式和反向模式 AD,但此类 API 已在路线图上,因此如果您在实践中遇到此限制,请打开一个问题描述您的用例。
此示例不支持的另一个 JAX 功能是高阶 AD。可以通过将上面的 res_norm_bwd
函数包装在 jax.custom_jvp()
或 jax.custom_vjp()
装饰器中来解决此问题,但我们不会在此处深入探讨该高级用例的详细信息。
GPU 上的 FFI 调用#
到目前为止,我们仅与在 CPU 上运行的外部函数进行交互,但 JAX 的 FFI 也支持对 GPU 代码的调用。由于此文档页面是在无法访问 GPU 的计算机上自动生成的,因此我们无法在此处执行任何特定于 GPU 的示例,但我们将介绍关键点。
在为 CPU 定义 FFI 包装器时,我们使用的函数签名是:
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
要更新此签名以与 CUDA 内核交互,此签名变为:
ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
并且处理程序定义已更新为在其绑定中包含 Ctx
:
XLA_FFI_DEFINE_HANDLER(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
然后,RmsNormImpl
可以使用 CUDA 流来启动 CUDA 内核。
在前端,注册代码将被更新以指定适当的平台:
jex.ffi.register_ffi_target(
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)
支持多个平台#
为了支持在 GPU 和 CPU 上运行我们的 rms_norm
函数,我们可以将上面的实现与 jax.lax.platform_dependent()
函数结合使用:
def rms_norm_cross_platform(x, eps=1e-5):
assert x.dtype == jnp.float32
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
def impl(target_name):
return lambda x: jex.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)
此版本的函数将根据运行时平台调用适当的 FFI 目标。
顺便一提,值得注意的是,虽然 jaxpr 和降低的 HLO 都包含对两个 FFI 目标的引用:
jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[3,5]. let
b:i32[] = platform_index[has_default=False platforms=(('cpu',), ('cuda',))]
c:i32[] = clamp 0 b 1
d:f32[3,5] = cond[
branches=(
{ lambda ; e:f32[3,5]. let
f:f32[3,5] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[3,5]),)
target_name=rms_norm
vectorized=Deprecated
vmap_method=broadcast_all
] e
in (f,) }
{ lambda ; g:f32[3,5]. let
h:f32[3,5] = ffi_call[
attributes=(('eps', np.float32(1e-05)),)
custom_call_api_version=4
has_side_effect=False
input_layouts=((1, 0),)
input_output_aliases=()
legacy_backend_config=None
output_layouts=((1, 0),)
result_avals=(ShapedArray(float32[3,5]),)
target_name=rms_norm_cuda
vectorized=Deprecated
vmap_method=broadcast_all
] g
in (h,) }
)
] c a
in (d,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3x5xf32>) -> (tensor<3x5xf32> {jax.result_info = ""}) {
%c = stablehlo.constant dense<0> : tensor<i32>
%c_0 = stablehlo.constant dense<0> : tensor<i32>
%c_1 = stablehlo.constant dense<1> : tensor<i32>
%0 = stablehlo.clamp %c_0, %c, %c_1 : tensor<i32>
%1 = "stablehlo.case"(%0) ({
%2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<3x5xf32>) -> tensor<3x5xf32>
stablehlo.return %2 : tensor<3x5xf32>
}, {
%2 = stablehlo.custom_call @rms_norm_cuda(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<3x5xf32>) -> tensor<3x5xf32>
stablehlo.return %2 : tensor<3x5xf32>
}) : (tensor<i32>) -> tensor<3x5xf32>
return %1 : tensor<3x5xf32>
}
}
但在编译函数时,已选择适当的 FFI:
print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[3,5]{1,0})->f32[3,5]{1,0}}
ENTRY main.3 {
Arg_0.1 = f32[3,5]{1,0} parameter(0)
ROOT custom-call.2 = f32[3,5]{1,0} custom-call(Arg_0.1), custom_call_target="rms_norm", operand_layout_constraints={f32[3,5]{1,0}}, api_version=API_VERSION_TYPED_FFI
}
并且使用 jax.lax.platform_dependent()
不会有任何运行时开销,并且编译的程序不会包含对不可用的 FFI 目标的任何引用。
高级主题#
本教程涵盖了使用 JAX 的 FFI 启动和运行所需的大部分基本步骤,但高级用例可能需要更多功能。我们将把这些主题留给未来的教程,但这里有一些可能有用的参考资料:
支持多种数据类型:在本教程的示例中,我们仅限于支持
float32
输入和输出,但许多用例需要支持多种不同的输入类型。一种处理方法是为所有支持的输入类型注册不同的 FFI 目标,然后使用 Python 根据输入类型为jax.extend.ffi.ffi_call()
选择适当的目标。但是,根据支持情况的组合,这种方法可能会很快变得难以管理。因此,也可以定义 C++ 处理程序来接受ffi::AnyBuffer
而不是ffi::Buffer<Dtype>
。然后,输入缓冲区将包含一个element_type()
方法,该方法可用于在后端定义适当的数据类型分发逻辑。分片:当在
jit()
中使用 JAX 的自动数据相关并行时,使用ffi_call()
实现的 FFI 调用没有足够的信息来进行适当的分片,因此它们会导致将输入复制到所有设备,并且 FFI 调用会在每个设备上的完整数组上执行。为了解决这个限制,可以使用shard_map()
或custom_partitioning()
。有状态的外部函数:也可以使用 FFI 来封装具有关联状态的函数。XLA 测试套件中包含一个低级别示例,未来的教程将包含更多详细信息。