Pallas 更新日志#
这是特定于 jax.experimental.pallas
的更改列表。有关整体 JAX 更改日志,请参见此处。
在 jax 0.5.0 中发布#
新功能
添加了对 TPU 上
jax.experimental.pallas.debug_print()
的向量支持。
在 jax 0.4.37 中发布#
新功能
为 Triton 后端的
dot
降低添加了对DotAlgorithmPreset
精度参数的支持。
在 jax 0.4.36 中发布 (2024 年 12 月 6 日)#
在 jax 0.4.35 中发布 (2024 年 10 月 22 日)#
移除
移除了之前已弃用的别名
jax.experimental.pallas.tpu.CostEstimate
和jax.experimental.tpu.run_scoped()
。现在它们都可以在jax.experimental.pallas
中找到。
新功能
添加了一个成本估算工具
pl.estimate_cost()
,用于从 JAX 参考函数自动构建内核成本估算。
随 jax 0.4.34 版本发布(2024 年 10 月 4 日)#
变更
jax.experimental.pallas.debug_print()
不再要求所有参数都是标量。参数的限制是后端特定的:目前仅在使用 Triton 时,GPU 上支持非标量参数。jax.experimental.pallas.BlockSpec
不再支持之前已弃用的参数顺序,即index_map
在block_shape
之前。
弃用
为了避免与
jax.experimental.pallas.mosaic_gpu
产生歧义,jax.experimental.pallas.gpu
子模块已被弃用。要使用 Triton 后端,请导入jax.experimental.pallas.triton
。
新功能
jax.experimental.pallas.pallas_call()
现在接受scratch_shapes
,这是一个 PyTree,用于指定内核所需的后端特定的临时对象,例如缓冲区、同步原语等。当使用
pltpu.enable_runtime_assert(True)
上下文管理器调用 pallas_call 时,现在可以使用checkify.check()
插入运行时断言。
随 jax 0.4.33 版本发布(2024 年 9 月 16 日)#
随 jax 0.4.32 版本发布(2024 年 9 月 11 日)#
变更
内核函数不允许关闭常量。相反,所有需要的数组都必须作为输入传递,并具有正确的块规范 (#22746)。
新功能
改进了索引映射函数签名错误的错误消息,包括索引映射的名称和源位置。
随 jax 0.4.31 版本发布(2024 年 7 月 29 日)#
变更
jax.experimental.pallas.BlockSpec
现在期望block_shape
在index_map
之前 传递。旧的参数顺序已被弃用,并将在未来的版本中删除。jax.experimental.pallas.GridSpec
不再具有in_specs_tree
和out_specs_tree
字段,并且in_specs
和out_specs
树现在将值存储为 BlockSpec 的 pytrees。之前,in_specs
和out_specs
是扁平化的 (#22552)。jax.experimental.pallas.GridSpec
的方法compute_index
已被删除,因为它是一个私有方法。同样,get_grid_mapping
和unzip_dynamic_bounds
已从BlockSpec
中删除 (#22593)。修复了解释模式,使其可以与涉及填充的 BlockSpec 一起使用 (#22275)。解释模式中的填充将使用 NaN,以帮助调试越界错误,但是当在自定义内核模式下运行时,此行为不存在,并且不应依赖它。
以前,可以导入许多旨在设为私有的 API,例如
jax.experimental.pallas.pallas
。现在不再可能了。
新功能
添加了 BlockSpec 的文档:网格和块规范。
改进了
jax.experimental.pallas.pallas_call()
API 的错误消息。为 TPU 添加了
lax.shift_right_arithmetic
(#22279) 和lax.erf_inv
(#22310) 的降级规则。为 Pallas TPU 自定义内核添加了对形状多态性的初始支持
(#22084).添加了对 checkify 的 TPU 支持。 (#22480)
当块大小与 TPU 要求不匹配时,添加了更清晰的错误消息。以前,错误来自 Mosaic 后端,并且没有有用的 Python 堆栈跟踪。
添加了对具有 1D 块的 TPU 降级的支持,并放宽了对至少 2 个维度的块大小的要求:最后 2 个维度必须分别可以被 8 和 128 整除,除非它们跨越相应的整个数组维度。以前,只有当最后两个维度的块维度分别小于 8 和 128 时,才允许跨越整个数组的块维度。
随 JAX 0.4.30 版本发布(2024 年 6 月 18 日)#
新功能
在解释模式下添加了对
jax.experimental.pallas.pallas_call()
的 checkify 支持 (#21862)。改进了对 TPU 内核的 PRNG 密钥的支持 (#21773)。