Pallas 变更日志#
这是特定于 jax.experimental.pallas
的更改列表。有关 JAX 的整体变更日志,请参阅 此处。
随 jax 0.4.32 发布#
变更
内核函数不允许闭包常量。相反,所有需要的数组都必须作为输入传递,并带有适当的块规格 (#22746).
弃用
新功能
改进了索引映射函数签名中错误的错误消息,以包括索引映射的名称和源位置。
随 jax 0.4.31 发布 (2024 年 7 月 29 日)#
变更
jax.experimental.pallas.BlockSpec
现在期望在index_map
之前传递block_shape
。旧的参数顺序已弃用,将在将来的版本中删除。jax.experimental.pallas.GridSpec
不再拥有in_specs_tree
和out_specs_tree
字段,并且in_specs
和out_specs
树现在将值存储为 BlockSpec 的 pytrees。以前,in_specs
和out_specs
是扁平化的 (#22552).方法
compute_index
在jax.experimental.pallas.GridSpec
中已被移除,因为它属于私有方法。类似地,get_grid_mapping
和unzip_dynamic_bounds
已从BlockSpec
中移除 (#22593)。修正了解释模式以支持包含填充的 BlockSpec (#22275)。解释模式中的填充将使用 NaN,以便于调试越界错误,但这种行为在自定义内核模式下不存在,也不应依赖它。
之前,可以导入许多旨在保持私有的 API,例如
jax.experimental.pallas.pallas
。现在已不再允许。
弃用
新功能
添加了 BlockSpec 的文档:网格和 BlockSpec。
改进了
jax.experimental.pallas.pallas_call()
API 的错误消息。添加了 TPU 的
lax.shift_right_arithmetic
(#22279) 和lax.erf_inv
(#22310) 的降低规则。为 Pallas TPU 自定义内核添加了对形状多态性的初步支持。
(#22084).添加了对 checkify 的 TPU 支持。(#22480)。
当块大小与 TPU 要求不匹配时,添加了更清晰的错误消息。之前,错误来自 Mosaic 后端,没有提供有用的 Python 堆栈跟踪。
添加了对使用 1D 块的 TPU 降低的支持,并放宽了至少有两个维度的块大小的要求:最后两个维度必须分别能被 8 和 128 整除,除非它们跨越了相应的数组维度。之前,仅当最后两个维度的块维度小于 8 和 128 时,才能允许跨越整个数组的块维度。
与 JAX 0.4.30 (2024 年 6 月 18 日) 一起发布#
新功能
添加了对解释模式下的
jax.experimental.pallas.pallas_call()
的 checkify 支持 (#21862)。改进了对 TPU 内核的 PRNG 密钥的支持 (#21773)。