JAX:高性能数组计算

内容

JAX:高性能数组计算#

JAX 是一个用于加速器导向的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习而设计。

熟悉的 API

JAX 提供了熟悉的 NumPy 风格的 API,以便研究人员和工程师轻松上手。

转换

JAX 包括用于编译、批处理、自动微分和并行化的可组合函数转换。

随处运行

相同的代码可以在多个后端上执行,包括 CPU、GPU 和 TPU

入门
JAX 入门
用户指南
用户指南
开发者笔记
开发者笔记

如果您想训练神经网络,请使用 Flax 并从其教程开始。对于一个基于 JAX 的端到端转换器库,请参阅 MaxText

生态系统#

JAX 本身范围狭窄,专注于高效的数组操作和程序转换。围绕 JAX 构建了一个不断发展的机器学习和数值计算工具生态系统;以下是其中的一小部分示例

神经网络

优化器和求解器

其他工具

概率编程

概率建模

物理和仿真

大型语言模型

许多其他基于 JAX 的库也已被开发出来;社区运营的 Awesome JAX 页面维护着一个最新的列表。