Pallas 变更日志#

这是特定于 jax.experimental.pallas 的更改列表。有关 JAX 的整体变更日志,请参阅 此处

随 jax 0.4.32 发布#

  • 变更

    • 内核函数不允许闭包常量。相反,所有需要的数组都必须作为输入传递,并带有适当的块规格 (#22746).

  • 弃用

  • 新功能

    • 改进了索引映射函数签名中错误的错误消息,以包括索引映射的名称和源位置。

随 jax 0.4.31 发布 (2024 年 7 月 29 日)#

  • 变更

    • jax.experimental.pallas.BlockSpec 现在期望在 index_map 之前传递 block_shape。旧的参数顺序已弃用,将在将来的版本中删除。

    • jax.experimental.pallas.GridSpec 不再拥有 in_specs_treeout_specs_tree 字段,并且 in_specsout_specs 树现在将值存储为 BlockSpec 的 pytrees。以前,in_specsout_specs 是扁平化的 (#22552).

    • 方法 compute_indexjax.experimental.pallas.GridSpec 中已被移除,因为它属于私有方法。类似地,get_grid_mappingunzip_dynamic_bounds 已从 BlockSpec 中移除 (#22593)。

    • 修正了解释模式以支持包含填充的 BlockSpec (#22275)。解释模式中的填充将使用 NaN,以便于调试越界错误,但这种行为在自定义内核模式下不存在,也不应依赖它。

    • 之前,可以导入许多旨在保持私有的 API,例如 jax.experimental.pallas.pallas。现在已不再允许。

  • 弃用

  • 新功能

    • 添加了 BlockSpec 的文档:网格和 BlockSpec

    • 改进了 jax.experimental.pallas.pallas_call() API 的错误消息。

    • 添加了 TPU 的 lax.shift_right_arithmetic (#22279) 和 lax.erf_inv (#22310) 的降低规则。

    • 为 Pallas TPU 自定义内核添加了对形状多态性的初步支持。
      (#22084).

    • 添加了对 checkify 的 TPU 支持。(#22480)。

    • 当块大小与 TPU 要求不匹配时,添加了更清晰的错误消息。之前,错误来自 Mosaic 后端,没有提供有用的 Python 堆栈跟踪。

    • 添加了对使用 1D 块的 TPU 降低的支持,并放宽了至少有两个维度的块大小的要求:最后两个维度必须分别能被 8 和 128 整除,除非它们跨越了相应的数组维度。之前,仅当最后两个维度的块维度小于 8 和 128 时,才能允许跨越整个数组的块维度。

与 JAX 0.4.30 (2024 年 6 月 18 日) 一起发布#