Pallas 更新日志#

这是 jax.experimental.pallas 的特定更改列表。有关整体 JAX 更改日志,请参见此处

随 jax 0.4.37 版本发布#

  • 新功能

    • 为 Triton 后端的 dot 降低添加了对 DotAlgorithmPreset 精度参数的支持。

随 jax 0.4.36 版本发布 (2024 年 12 月 6 日)#

随 jax 0.4.35 版本发布 (2024 年 10 月 22 日)#

  • 移除

    • 移除了之前已弃用的别名 jax.experimental.pallas.tpu.CostEstimatejax.experimental.tpu.run_scoped()。两者现在都可以在 jax.experimental.pallas 中找到。

  • 新功能

    • 添加了一个成本估算工具 pl.estimate_cost(),用于从 JAX 参考函数自动构建内核成本估算。

随 jax 0.4.34 版本发布(2024 年 10 月 4 日)#

随 jax 0.4.33 版本发布(2024 年 9 月 16 日)#

随 jax 0.4.32 版本发布(2024 年 9 月 11 日)#

  • 变更

    • 内核函数不允许闭包常量。相反,所有需要的数组都必须作为输入传递,并具有适当的块规范(#22746)。

  • 新功能

    • 改进了索引映射函数签名错误的错误消息,以包含索引映射的名称和源位置。

随 jax 0.4.31 版本发布(2024 年 7 月 29 日)#

  • 变更

    • jax.experimental.pallas.BlockSpec 现在要求在 index_map 之前传递 block_shape。旧的参数顺序已弃用,将在未来的版本中删除。

    • jax.experimental.pallas.GridSpec 不再具有 in_specs_treeout_specs_tree 字段,并且 in_specsout_specs 树现在将值存储为 BlockSpec 的 pytree。以前,in_specsout_specs 被展平了(#22552)。

    • jax.experimental.pallas.GridSpeccompute_index 方法已被删除,因为它是私有的。类似地,get_grid_mappingunzip_dynamic_bounds 已从 BlockSpec 中删除(#22593)。

    • 修复了解释模式,使其可以与涉及填充的 BlockSpec 一起使用(#22275)。解释模式下的填充将使用 NaN,以帮助调试越界错误,但此行为在自定义内核模式下运行时不存在,不应依赖此行为。

    • 以前,可以导入许多旨在私有的 API,例如 jax.experimental.pallas.pallas。现在不再允许这样做。

  • 新功能

    • 添加了 BlockSpec 的文档:网格和块规范

    • 改进了 jax.experimental.pallas.pallas_call() API 的错误消息。

    • 为 TPU 添加了 lax.shift_right_arithmetic#22279)和 lax.erf_inv#22310)的降阶规则。

    • 为 Pallas TPU 自定义内核添加了对形状多态性的初步支持。
      (#22084).

    • 添加了对 checkify 的 TPU 支持。(#22480

    • 当块大小与 TPU 要求不匹配时,添加了更清晰的错误消息。以前,错误来自 Mosaic 后端,并且没有有用的 Python 堆栈跟踪。

    • 添加了对具有 1D 块的 TPU 降阶的支持,并放宽了对至少具有 2 个维度的块大小的要求:最后 2 个维度必须分别可以被 8 和 128 整除,除非它们跨越整个相应的数组维度。以前,仅当最后两个维度中的块维度分别小于 8 和 128 时,才允许跨越整个数组的块维度。

随 JAX 0.4.30 版本发布(2024 年 6 月 18 日)#