Pallas:一种 JAX 内核语言

Pallas:一种 JAX 内核语言#

Pallas 是 JAX 的一个扩展,它允许为 GPU 和 TPU 编写自定义内核。它旨在提供对生成代码的细粒度控制,同时结合 JAX 追踪的高级人体工程学和 jax.numpy API。

本节包含使用 Pallas 的教程、指南和示例。另请参阅 jax.experimental.pallas 模块 API 文档。

警告

Pallas 处于实验阶段,并且经常发生变化。请参阅 Pallas 变更日志 以了解最近的更改。

您可以预期会遇到错误和未实现的情况,例如,当降低需要模拟的高级 JAX 概念时,或者仅仅因为 Pallas 仍处于开发阶段。