变更日志#
建议访问 此处 查看。有关实验性 Pallas API 的具体更改,请参阅 Pallas 变更日志。
jax 0.4.34#
删除
jax.xla_computation
已被删除。自 0.4.30 JAX 版本弃用以来已过去 3 个月。请使用 AOT API 获取与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
.out_info
属性jax.stages.Lowered
获取输出信息(例如树结构、形状和数据类型)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jax.ShapeDtypeStruct
不再接受named_shape
参数。该参数仅供xmap
使用,后者在 0.4.31 中已删除。
jax 0.4.33 (2024年9月16日)#
这是基于 jax 0.4.32 的修补程序版本,修复了该版本中发现的两个错误。
在 JAX 0.4.32 版本固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏错误,该错误仅在同一作业中存在多个 TPU 切片时才会出现,例如,如果在多个 v5e 切片上进行训练。此版本通过固定 libtpu
的修复版本来解决此问题。
此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。
jax 0.4.32 (2024年9月11日)#
注意:此版本已从 PyPi 中移除,因为存在 TPU 上的数据损坏错误。有关更多详细信息,请参阅 0.4.33 版本说明。
新功能
添加了
jax.extend.ffi.ffi_call()
和jax.extend.ffi.ffi_lowering()
以支持使用新的 外部函数接口 (FFI) 从 JAX 与自定义 C++ 和 CUDA 代码进行交互。
更改
jax_pmap_no_rank_reduction
标志默认设置为True
。pmap 结果上的 array[0] 现在会引入重新整形(请改用 array[0:1])。
每个分片形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有前导 (1, …)。相应地更新直接访问分片的代码。每个分片形状的秩现在与全局形状的秩匹配,这与 jit 的行为相同。这避免了将结果从 pmap 传递到 jit 时代价高昂的重新整形。
jax_enable_memories
标志默认设置为True
。jax.numpy
现在支持 Python 数组 API 标准的 v2023.12 版本。有关更多信息,请参阅 Python 数组 API 标准。CPU 后端上的计算现在可以在更多情况下异步调度。以前,非并行计算始终同步调度。您可以通过设置
jax.config.update('jax_cpu_enable_async_dispatch', False)
来恢复旧行为。添加了新的
jax.process_indices()
函数来替换在 JAX v0.2.13 中已弃用的jax.host_ids()
函数。为了与
numpy.fabs
的行为保持一致,jax.numpy.fabs
已修改为不再支持complex dtypes
。jax.tree_util.register_dataclass
现在检查data_fields
和meta_fields
是否包含所有具有init=True
的数据类字段,并且仅包含它们,如果nodetype
是一个数据类。几个
jax.numpy
函数现在具有完整的ufunc
接口,包括add
、multiply
、bitwise_and
、bitwise_or
、bitwise_xor
、logical_and
、logical_and
和logical_and
。在未来的版本中,我们计划将其扩展到其他 ufunc。添加了
jax.lax.optimization_barrier()
,允许用户阻止编译器优化(例如公共子表达式消除)并控制调度。
重大更改
MHLO MLIR 方言(
jax.extend.mlir.mhlo
)已移除。请改用stablehlo
方言。
弃用
jax.numpy.clip()
和jax.numpy.hypot()
的复杂输入不再允许,此前自 JAX v0.4.27 起已弃用。已弃用以下 API
jax.lib.xla_bridge.xla_client
:请直接使用jax.lib.xla_client
。jax.lib.xla_bridge.get_backend
:请使用jax.extend.backend.get_backend()
。jax.lib.xla_bridge.default_backend
:请使用jax.extend.backend.default_backend()
。
jax.experimental.array_api
模块已弃用,并且不再需要导入它来使用数组 API。jax.numpy
直接支持数组 API;有关更多信息,请参阅 Python 数组 API 标准。内部实用程序
jax.core.check_eqn
、jax.core.check_type
和jax.core.check_valid_jaxtype
现已弃用,将来将被移除。jax.numpy.round_
已弃用,因为 NumPy 2.0 中已移除相应的 API。请改用jax.numpy.round()
。将 DLPack 胶囊传递给
jax.dlpack.from_dlpack()
已弃用。jax.dlpack.from_dlpack()
的参数应来自实现__dlpack__
协议的其他框架的数组。
jaxlib 0.4.32 (2024 年 9 月 11 日)#
注意:此版本已从 PyPi 中移除,因为存在 TPU 上的数据损坏错误。有关更多详细信息,请参阅 0.4.33 版本说明。
重大更改
添加了密封的 CUDA 支持。密封的 CUDA 使用特定可下载版本的 CUDA,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后在各种 Bazel 目标中使用 CUDA 库和工具作为依赖项。这使得 JAX 及其支持的 CUDA 版本能够构建更可重复的版本。
更改
添加了 SparseCore 分析。
JAX 现在支持分析 TPUv5p 芯片上的 SparseCore。这些跟踪将在 Tensorboard Profiler 的 TraceViewer 中可见。
jax 0.4.31 (2024 年 7 月 29 日)#
删除
xmap 已删除。请使用
shard_map()
作为替换。
更改
最低 CuDNN 版本为 v9.1。这在以前的版本中也是如此,但我们现在正式声明此版本约束。
最低 Python 版本现在为 3.10。3.10 将保持最低支持版本,直到 2025 年 7 月。
最低 NumPy 版本现在为 1.24。NumPy 1.24 将保持最低支持版本,直到 2024 年 12 月。
最低 SciPy 版本现在为 1.10。SciPy 1.10 将保持最低支持版本,直到 2025 年 1 月。
jax.numpy.ceil()
、jax.numpy.floor()
和jax.numpy.trunc()
现在返回与输入相同数据类型的输出,即不再将整数或布尔输入提升为浮点数。libdevice.10.bc
不再与 CUDA 轮子捆绑在一起。它必须作为本地 CUDA 安装的一部分或通过 NVIDIA 的 CUDA pip 轮子进行安装。jax.experimental.pallas.BlockSpec
现在期望在index_map
之前传递block_shape
。旧的参数顺序已弃用,将在未来的版本中移除。更新了 GPU 设备的 repr,使其与 TPU/CPU 更一致。例如,
cuda(id=0)
现在将是CudaDevice(id=0)
。添加了
device
属性和to_device
方法到jax.Array
,作为 JAX 的 数组 API 支持的一部分。
弃用
移除了一些先前已弃用的与多态形状相关的内部 API。从
jax.core
:移除canonicalize_shape
、dimension_as_value
、definitely_equal
和symbolic_equal_dim
。HLO 降级规则不再应将单例 ir.Values 包装在元组中。而是返回未包装的单例 ir.Values。对包装值的支持将在 JAX 的未来版本中移除。
jax.experimental.jax2tf.convert()
且native_serialization=False
或enable_xla=False
现在已弃用,并且此支持将在未来的版本中移除。自 JAX 0.4.16(2023 年 9 月)起,本机序列化已成为默认值。先前已弃用的函数
jax.random.shuffle
已移除;请改用jax.random.permutation
且independent=True
。
jaxlib 0.4.31 (2024 年 7 月 29 日)#
错误修复
修复了一个错误,该错误会导致 jit 分派的快速路径错误处理 jit 的负 static_argnums。
修复了一个错误,该错误会导致奇异矩阵批次的三角求解产生无意义的有限值,而不是无穷大或 NaN(#3589、#15429)。
jax 0.4.30 (2024 年 6 月 18 日)#
更改
JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升至 0.4.0,但在本版本中已回滚,以便 TensorFlow 和 JAX 的用户有更多时间迁移到较新的 TensorFlow 版本。
jax.experimental.mesh_utils
现在可以为 TPU v5e 创建高效的网格。jax 现在直接依赖于 jaxlib。此更改由 CUDA 插件切换启用:不再有多个 jaxlib 变体。您可以使用
pip install jax
安装仅 CPU 的 jax,无需任何额外内容。添加了一个用于导出和序列化 JAX 函数的 API。这曾经存在于
jax.experimental.export
(即将弃用)中,现在将位于jax.export
中。请参阅 文档。
弃用
内部漂亮打印工具
jax.core.pp_*
已弃用,将在未来的版本中移除。跟踪器的散列已弃用,将在未来的 JAX 版本中导致
TypeError
。这以前是这种情况,但在最近几个 JAX 版本中出现了一个意外的回归。jax.experimental.export
已弃用。请改用jax.export
。请参阅 迁移指南。在大多数情况下,传递数组代替数据类型现在已弃用;例如,对于数组
x
和y
,x.astype(y)
将引发警告。要使其静音,请使用x.astype(y.dtype)
。jax.xla_computation
已弃用,将在未来的版本中移除。请使用 AOT API 来获得与jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您还可以使用
.out_info
属性jax.stages.Lowered
获取输出信息(例如树结构、形状和数据类型)。对于跨后端降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jaxlib 0.4.30 (2024 年 6 月 18 日)#
对单片 CUDA jaxlib 的支持已被删除。您必须使用基于插件的安装(
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29 (2024 年 6 月 10 日)#
更改
我们预计这将是支持单片 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本。未来的版本将使用 CUDA 插件 jaxlib(例如
pip install jax[cuda12]
)。JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。
移除对
jax.experimental.export
API 旧用法的向后兼容性支持。现在无法再使用from jax.experimental.export import export
,而是应该使用from jax.experimental import export
。已移除的功能自 0.4.24 起已弃用。添加了
is_leaf
参数到jax.tree.all()
和jax.tree_util.tree_all()
。
弃用
jax.sharding.XLACompatibleSharding
已弃用。请使用jax.sharding.Sharding
。jax.experimental.Exported.in_shardings
已重命名为jax.experimental.Exported.in_shardings_hlo
。out_shardings
也一样。旧名称将在 3 个月后移除。移除了一些之前已弃用的 API
jax.numpy.linalg.matrix_rank()
的tol
参数即将弃用,并将很快被移除。请使用rtol
代替。jax.numpy.linalg.pinv()
的rcond
参数即将弃用,并将很快被移除。请使用rtol
代替。已移除弃用的
jax.config
子模块。要配置 JAX,请使用import jax
,然后通过jax.config
引用配置对象。jax.random
API 不再接受批处理键,之前一些 API 无意中接受了。今后,我们建议在这些情况下显式使用jax.vmap()
。在
jax.scipy.special.beta()
中,x
和y
参数已重命名为a
和b
,以与其他beta
API 保持一致。
新功能
添加了
jax.experimental.Exported.in_shardings_jax()
,用于根据存储在Exported
对象中的 HloShardings 构造可在 JAX API 中使用的分片。
jaxlib 0.4.29 (2024年6月10日)#
错误修复
修复了一个 XLA 错误地分片了一些连接操作的错误,该错误表现为累积归约的输出不正确 (#21403)。
修复了 XLA:CPU 错误编译某些矩阵乘法融合的问题 (https://github.com/openxla/xla/pull/13301)。
修复了 GPU 上的编译器崩溃问题 (https://github.com/google/jax/issues/21396)。
弃用
jax.tree.map(f, None, non-None)
现在会发出DeprecationWarning
,并在 JAX 的未来版本中引发错误。None
仅是其自身的树前缀。要保留当前行为,您可以要求jax.tree.map
将None
视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28 (2024年5月9日)#
错误修复
撤销了对
make_jaxpr
的更改,该更改导致 Equinox 出现问题 (#21116)。
弃用和移除
jax.numpy.sort()
和jax.numpy.argsort()
的kind
参数现已移除。请使用stable=True
或stable=False
代替。已从
jax.experimental.pallas.gpu
模块中移除get_compute_capability
。请改用 GPU 设备的compute_capability
属性,该属性由jax.devices()
或jax.local_devices()
返回。jax.numpy.reshape()
的newshape
参数即将弃用,并将很快被移除。请使用shape
代替。
更改
此版本的最低 jaxlib 版本为 0.4.27。
jaxlib 0.4.28 (2024年5月9日)#
错误修复
修复了 Python 3.10 或更低版本中 Array 和 JIT Python 对象的类型名称中的内存损坏错误。
修复了 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target
。修复了 CPU 上的编译速度慢的问题。
更改
Windows 版本现在使用 Clang 而不是 MSVC 构建。
jax 0.4.27 (2024年5月7日)#
新功能
添加了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循它们在 array API 2023 标准中的添加,该标准即将被 NumPy 采用。添加了一个新的配置选项
jax_cpu_collectives_implementation
,用于选择 CPU 后端使用的跨进程集体操作的实现。可用的选项有'none'
(默认)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果设置为'none'
,则跨进程集体操作将被禁用。
更改
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
现在使用jax.Array
而不是np.ndarray
。您可以通过在将参数传递给回调之前通过jax.tree.map(np.asarray, args)
将其转换来恢复旧行为。complex_arr.astype(bool)
现在遵循与 NumPy 相同的语义,在complex_arr
等于0 + 0j
时返回 False,否则返回 True。core.Token
现在是一个非平凡的类,它包装了一个jax.Array
。它可以被创建并在线程中进出计算以建立依赖关系。单例对象core.token
已被移除,用户现在应该创建并使用新的core.Token
对象。在 GPU 上,Threefry PRNG 实现默认不再降低到内核调用。此选择可以在编译时成本下提高运行时内存使用率。可以通过
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢复之前产生内核调用的行为。如果新的默认值导致问题,请提交错误报告。否则,我们打算在未来的版本中移除此标志。
弃用和移除
Pallas 现在专门使用 XLA 来编译 GPU 上的内核。通过 Triton Python API 的旧降低传递已被移除,并且
JAX_TRITON_COMPILE_VIA_XLA
环境变量不再有任何作用。jax.numpy.clip()
有一个新的参数签名:a
、a_min
和a_max
已弃用,取而代之的是x
(仅限位置)、min
和max
(#20550)。JAX 数组的
device()
方法已移除,自 JAX v0.4.21 以来一直处于弃用状态。请改用arr.devices()
。传递给
jax.nn.softmax()
和jax.nn.log_softmax()
的initial
参数已弃用;现在支持 softmax 的空输入,无需设置此参数。在
jax.jit()
中,传递无效的static_argnums
或static_argnames
现在会导致错误,而不是警告。jaxlib 的最低版本现在为 0.4.23。
当
jax.numpy.hypot()
函数传递复数值输入时,现在会发出弃用警告。在弃用完成后,这将引发错误。传递给
jax.numpy.nonzero()
、jax.numpy.where()
及相关函数的标量参数现在会引发错误,这与 NumPy 中的类似更改一致。配置选项
jax_cpu_enable_gloo_collectives
已弃用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。在 JAX v0.4.22 中弃用后,
jax.Array.device_buffer
和jax.Array.device_buffers
方法已被移除。请改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。现在,
jax.numpy.where
的condition
、x
和y
参数为仅限位置参数,这遵循 JAX v0.4.21 中关键字的弃用。现在,
jax.lax.linalg
中函数的非数组参数必须通过关键字指定。以前,这会引发 DeprecationWarning。现在,在几个 :func:
jax.numpy
API 中需要类数组参数,包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
。
错误修复
当
copy=True
时,jax.numpy.astype()
现在将始终返回副本。以前,当输出数组与输入数组具有相同的 dtype 时,不会创建副本。这可能会导致内存使用量略有增加。默认值为copy=False
,以保持向后兼容性。
jaxlib 0.4.27 (2024年5月7日)#
jax 0.4.26 (2024年4月3日)#
新功能
添加了
jax.numpy.trapezoid()
,这遵循 NumPy 2.0 中添加此函数。
更改
复数值
jax.numpy.geomspace()
现在选择与 NumPy 2.0 一致的对数螺旋分支。在
jax.vmap
下,lax.rng_bit_generator
以及'rbg'
和'unsafe_rbg'
PRNG 实现的行为 已更改,以便批处理中的第一个密钥仅用于映射密钥上的随机生成。文档现在使用
jax.random.key
来构建 PRNG 密钥数组,而不是jax.random.PRNGKey
。
弃用和移除
jax.tree_map()
已弃用;请改用jax.tree.map
,或者为了与旧版 JAX 版本向后兼容,请使用jax.tree_util.tree_map()
。jax.clear_backends()
已弃用,因为它不一定能实现其名称所暗示的功能,并且可能导致意外的后果,例如,它不会销毁现有的后端并释放相应的已拥有资源。如果您只想清理编译缓存,请使用jax.clear_caches()
。为了向后兼容性或您确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
。jax.experimental.maps
模块和jax.experimental.maps.xmap
已弃用。请使用jax.experimental.shard_map
或带有spmd_axis_name
参数的jax.vmap
来表达 SPMD 设备并行计算。jax.experimental.host_callback
模块已弃用。请改用 新的 JAX 外部回调。添加了JAX_HOST_CALLBACK_LEGACY
标志来帮助过渡到新的回调。有关讨论,请参阅 #20385。传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
的无法转换为 JAX 数组的参数现在会导致异常。已删除弃用的标志
jax_parallel_functions_output_gda
。此标志已弃用很长时间,并且没有任何作用;其使用是无操作的。以前弃用的导入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
现已删除。请改用jax.config
和jax.extend.source_info_util
。JAX 导出不再支持旧的序列化版本。版本 9 自 2023 年 10 月 27 日起得到支持,并自 2024 年 2 月 1 日起成为默认版本。请参阅 版本说明。此更改可能会破坏设置低于 9 的特定 JAX 序列化版本的客户端。
jaxlib 0.4.26 (2024年4月3日)#
更改
JAX 现在仅支持 CUDA 12.1 或更高版本。对 CUDA 11.8 的支持已删除。
JAX 现在支持 NumPy 2.0。
jax 0.4.25 (2024年2月26日)#
新功能
添加了 CUDA 数组接口 导入支持(需要 jaxlib 0.4.24)。
JAX 数组现在支持 NumPy 样式的标量布尔索引,例如
x[True]
或x[False]
。添加了
jax.tree
模块,该模块提供了一个更方便的接口来引用jax.tree_util
中的函数。jax.tree.transpose()
(即jax.tree_util.tree_transpose()
)现在接受inner_treedef=None
,在这种情况下,将自动推断内部 treedef。
更改
Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA
环境变量设置为"0"
来恢复旧的行为。v0.4.24 中删除的
jax.interpreters.xla
中的一些已弃用 API 已在 v0.4.25 中重新添加,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。这些仍然被认为已弃用,并且将在将来提供更好的替代方案后再次删除。有关讨论,请参阅 #19816。
弃用和移除
jax.numpy.linalg.solve()
现在对于批处理一维求解(其中b.ndim > 1
)显示弃用警告。未来,这些将被视为批处理二维求解。将非标量数组转换为 Python 标量现在会引发错误,无论数组大小如何。之前在大小为 1 的非标量数组情况下会发出弃用警告。这遵循 NumPy 中类似的弃用。
先前已弃用的配置 API 已在标准的 3 个月弃用周期后移除(请参阅 API 兼容性)。这些包括
jax.config.config
对象,以及jax.config
的define_*_state
和DEFINE_*
方法。
通过
import jax.config
导入jax.config
子模块已被弃用。要配置 JAX,请使用import jax
,然后通过jax.config
引用配置对象。jaxlib 的最低版本现在为 0.4.20。
jaxlib 0.4.25 (2024 年 2 月 26 日)#
jax 0.4.24 (2024 年 2 月 6 日)#
更改
JAX 降级到 StableHLO 不再依赖于物理设备。如果您的原语在降级规则中包装了 custom_partitioning 或 JAX 回调,例如传递给
mlir.register_lowering
的rule
参数的函数,则将您的原语添加到jax._src.dispatch.prim_requires_devices_during_lowering
集合中。这是必需的,因为 custom_partitioning 和 JAX 回调需要物理设备在降级期间创建Sharding
。这是一种临时状态,直到我们能够在没有物理设备的情况下创建Sharding
。jax.numpy.argsort()
和jax.numpy.sort()
现在支持stable
和descending
参数。对形状多态性处理的若干更改(用于
jax.experimental.jax2tf
和jax.experimental.export
)更清晰的符号表达式的漂亮打印 (#19227)
添加了指定维度变量符号约束的功能。这使得形状多态性更具表现力,并提供了一种解决推理不等式限制的方法。请参阅 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
随着符号约束的添加 (#19235),我们现在认为来自不同范围的维度变量是不同的,即使它们具有相同的名称。来自不同范围的符号表达式不能交互,例如,在算术运算中。范围由
jax.experimental.jax2tf.convert()
、jax.experimental.export.symbolic_shape()
、jax.experimental.export.symbolic_args_specs()
引入。符号表达式e
的范围可以通过e.scope
读取,并传递给上述函数以指示它们在给定范围内构造符号表达式。请参阅 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。简化并加快了相等比较,其中我们认为如果两个符号维度的差的标准化形式简化为 0,则它们相等 (#19231;请注意,这可能会导致用户可见的行为更改)
改进了不确定的不等式比较的错误消息 (#19235)。
core.non_negative_dim
API(最近引入)已弃用,并引入了core.max_dim
和core.min_dim
(#18953) 来表示符号维度的max
和min
。您可以使用core.max_dim(d, 0)
代替core.non_negative_dim(d)
。shape_poly.is_poly_dim
已弃用,推荐使用export.is_symbolic_dim
(#19282)。export.args_specs
已弃用,推荐使用export.symbolic_args_specs ({jax-issue}
#19283`)。shape_poly.PolyShape
和jax2tf.PolyShape
已弃用,请使用字符串表示多态形状规范 (#19284)。JAX 默认的原生序列化版本现在为 9。这与
jax.experimental.jax2tf
和jax.experimental.export
相关。请参阅 版本号说明。
重构了
jax.experimental.export
的 API。您现在应该使用from jax.experimental import export
而不是from jax.experimental.export import export
。旧的导入方式将在 3 个月的弃用期内继续有效。jax.numpy.unique()
当return_inverse = True
时,会返回重新整形为输入维度的逆索引,这与 NumPy 2.0 中对numpy.unique()
的类似更改一致。jax.numpy.sign()
现在对于非零复数输入返回x / abs(x)
。这与 NumPy 2.0 版本中numpy.sign()
的行为一致。jax.scipy.special.logsumexp()
当return_sign=True
时,现在使用 NumPy 2.0 中复数符号的约定,即x / abs(x)
。这与 SciPy v1.13 中scipy.special.logsumexp()
的行为一致。JAX 现在支持导入和导出 bool DLPack 类型。以前,bool 值无法导入,并作为整数导出。
弃用和移除
许多先前已弃用的函数已在标准的 3 个月以上弃用周期后移除(请参阅 API 兼容性)。这包括
来自
jax.core
:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。来自
jax.lax
:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。jax.linear_util
子模块及其所有内容。jax.prng
子模块及其所有内容。来自
jax.random
:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。来自
jax.tree_util
:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。从
jax.interpreters.xla
中导入:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。从
jax.numpy
中导入:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。从
jax.scipy.linalg
中导入:tril
和triu
。
先前已弃用的方法
PRNGKeyArray.unsafe_raw_array
现已移除。请改用jax.random.key_data()
。bool(empty_array)
现在会抛出错误,而不是返回False
。之前会发出弃用警告,此更改遵循 NumPy 中的类似更改。对 mhlo MLIR 方言的支持已弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。将来会移除引用“mhlo”的 API。请改用“stablehlo”方言。
jax.random
:将批处理密钥直接传递给随机数生成函数(例如bits()
、gamma()
等)已弃用,并将发出FutureWarning
。请使用jax.vmap
进行显式批处理。jax.lax.tie_in()
已弃用:自 JAX v0.2.0 以来,它一直是无操作的。
jaxlib 0.4.24 (2024年2月6日)#
更改
JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已放弃对 CUDA 12.2 的支持。
cost_analysis
现在可与交叉编译的Compiled
对象一起使用(即,当使用.lower().compile()
和拓扑对象时,例如,从非 TPU 计算机编译到 Cloud TPU)。添加了 CUDA 数组接口 导入支持(需要 jax 0.4.25)。
jax 0.4.23 (2023年12月13日)#
jaxlib 0.4.23 (2023年12月13日)#
修复了导致编译期间 GPU 编译器输出详细日志的错误。
jax 0.4.22 (2023年12月13日)#
弃用
JAX 数组的
device_buffer
和device_buffers
属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但可以通过以下方式恢复之前的输出arr.device_buffer
变为arr.addressable_data(0)
arr.device_buffers
变为[x.data for x in arr.addressable_shards]
jaxlib 0.4.22 (2023年12月13日)#
jax 0.4.21 (2023年12月4日)#
新功能
添加了
jax.nn.squareplus
。
更改
jaxlib 的最低版本现在为 0.4.19。
发布的轮子现在使用 clang 而不是 gcc 构建。
强制在调用
jax.distributed.initialize()
之前设备后端尚未初始化。在 Cloud TPU 环境中自动为
jax.distributed.initialize()
提供参数。
弃用
先前已弃用的
sym_pos
参数已从jax.scipy.linalg.solve()
中移除。请改用assume_a='pos'
。将
None
直接传递给jax.array()
或jax.asarray()
,或者在列表或元组中传递,已弃用,现在会引发FutureWarning
。它目前会被转换为 NaN,将来会引发TypeError
。通过关键字参数将
condition
、x
和y
参数传递给jax.numpy.where
已弃用,以匹配numpy.where
。将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
已弃用,现在会引发DeprecationWaning
。目前这些函数会返回 False,将来会引发异常。JAX 数组的
device()
方法已弃用。根据上下文,可以使用以下方法之一替换它jax.Array.devices()
返回数组使用的所有设备的集合。jax.Array.sharding
提供数组使用的分片配置。
jaxlib 0.4.21 (2023年12月4日)#
更改
为了准备添加分布式 CPU 支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备等同对待,即
jax.devices()
包括分布式作业中存在的所有设备,即使这些设备不是当前进程的本地设备。jax.local_devices()
仍然只包括当前进程的本地设备,因此,如果jax.devices()
的更改导致您的程序出现问题,则您很可能需要改用jax.local_devices()
。CPU 设备现在在分布式作业中接收全局唯一的 ID 号;之前 CPU 设备会接收进程本地 ID 号。
每个 CPU 设备的
process_index
现在将与同一进程中的任何 GPU 或 TPU 设备匹配;之前 CPU 设备的process_index
始终为 0。
在 NVIDIA GPU 上,JAX 现在更倾向于对最多 1024x1024 的矩阵使用 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。
错误修复
修复了将包含非有限值的数组传递给非对称特征分解时出现的错误/挂起问题 (#18226)。包含非有限值的数组现在会生成充满 NaN 的数组作为输出。
jax 0.4.20 (2023年11月2日)#
jaxlib 0.4.20 (2023年11月2日)#
错误修复
修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆问题。
jax 0.4.19 (2023年10月19日)#
新功能
添加了
jax.typing.DTypeLike
,可用于注释可转换为 JAX 数据类型的对象。添加了
jax.numpy.fill_diagonal
。
更改
JAX 现在需要 SciPy 1.9 或更高版本。
错误修复
在多控制器分布式 JAX 程序中,只有进程 0 会写入持久编译缓存条目。如果缓存放在网络文件系统(如 GCS)上,这可以修复写入争用问题。
在确定已安装的 cusolver 和 cufft 版本是否至少与 JAX 构建时使用的版本一样新时,版本检查不再考虑补丁版本。
jaxlib 0.4.19 (2023年10月19日)#
更改
如果安装了 pip 安装的 NVIDIA CUDA 库(nvidia-... 包),jaxlib 现在将始终优先于任何其他 CUDA 安装(包括在
LD_LIBRARY_PATH
中命名的安装)。如果这导致问题并且目的是使用系统安装的 CUDA,则解决方法是删除 pip 安装的 CUDA 库包。
jax 0.4.18 (2023年10月6日)#
jaxlib 0.4.18 (2023年10月6日)#
更改
CUDA jaxlib 现在依赖用户安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip
安装,则应自动安装 NCCL。目前,需要 NCCL 2.16 或更高版本。我们现在提供 Linux aarch64 轮子,包括有和没有 NVIDIA GPU 支持的版本。
jax.Array.item()
现在支持可选的索引参数。
弃用
jax.lax
中的一些内部实用程序和意外导出已弃用,将在未来版本中移除。jax.lax.dtypes
:请改用jax.dtypes
。jax.lax.itertools
:请改用itertools
。naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是内部实用程序,现已弃用且无替代方案。
错误修复
修复了 Cloud TPU 回归,该回归会导致编译因 smem 而出现内存不足错误。
jax 0.4.17 (2023年10月3日)#
新功能
添加了新的
jax.numpy.bitwise_count()
函数,与最近添加到 NumPy 的类似函数的 API 匹配。
弃用
删除了已弃用的模块
jax.abstract_arrays
及其所有内容。已弃用
jax.random
中的命名键构造器。请改用将impl
参数传递给jax.random.PRNGKey()
或jax.random.key()
random.threefry2x32_key(seed)
变为random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
变为random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
变为random.PRNGKey(seed, impl='unsafe_rbg')
更改
CUDA:JAX 现在会验证它找到的 CUDA 库是否至少与构建 JAX 时所用的 CUDA 库一样新。如果找到较旧的库,JAX 会引发异常,因为这比神秘的故障和崩溃更可取。
删除了“未找到 GPU/TPU”警告。改为在 Linux 上,如果找到 NVIDIA GPU 或 Google TPU 但未使用且未指定
--jax_platforms
时发出警告。jax.scipy.stats.mode()
现在如果模式是在大小为 0 的轴上获取的,则返回 0 个计数,这与 SciPy 1.11 中scipy.stats.mode
的行为一致。大多数
jax.numpy
函数和属性现在都具有完全定义的类型存根。以前,其中许多被像mypy
和pytype
这样的静态类型检查器视为Any
。
jaxlib 0.4.17 (2023年10月3日)#
更改
在此版本中添加了 Python 3.12 轮子。
CUDA 12 轮子现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。
错误修复
修复了初始化 JAX CPU 后端时 ABSL 发出的日志垃圾邮件。
jax 0.4.16 (2023年9月18日)#
更改
添加了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以将任何标量值函数转换为类似numpy.ufunc()
的对象,并具有诸如outer()
、reduce()
、accumulate()
、at()
和reduceat()
的方法 (#17054)。在未在 IPython 下运行时:当引发异常时,JAX 现在会从回溯中过滤掉其所有内部帧。(没有以前出现的“未过滤的堆栈跟踪”。)这应该会产生更友好的回溯。请参阅 此处 以了解示例。可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames
(对于两个单独的未过滤/已过滤回溯,这是旧的行为)或JAX_TRACEBACK_FILTERING=off
(对于一个未过滤的回溯)来更改此行为。jax2tf 默认序列化版本现在为 7,它引入了新的形状 安全断言。
传递给
jax.sharding.Mesh
的设备应该是可散列的。这特别适用于模拟设备或用户创建的设备。jax.devices()
已经是可散列的。
重大更改
jax2tf 现在默认使用原生序列化。有关详细信息以及覆盖默认值的机制,请参阅 jax2tf 文档。
已删除选项
--jax_coordination_service
。它现在始终为True
。已从公共 JAX 命名空间中删除了
jax.jaxpr_util
。JAX_USE_PJRT_C_API_ON_TPU
不再有任何作用(即,它始终默认为 true)。已删除 2021 年 12 月引入的向后兼容性标志
--jax_host_callback_ad_transforms
。
弃用
遵循 NumPy NEP-52,已弃用了一些
jax.numpy
API。已弃用
jax.numpy.NINF
。请改用-jax.numpy.inf
。已弃用
jax.numpy.PZERO
。请改用0.0
。已弃用
jax.numpy.NZERO
。请改用-0.0
。已弃用
jax.numpy.issubsctype(x, t)
。请改用jax.numpy.issubdtype(x.dtype, t)
。已弃用
jax.numpy.row_stack
。请改用jax.numpy.vstack
。已弃用
jax.numpy.in1d
。请改用jax.numpy.isin
。已弃用
jax.numpy.trapz
。请改用jax.scipy.integrate.trapezoid
。
已弃用
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
,遵循 SciPy。请改用jax.numpy.tril
和jax.numpy.triu
。在 JAX v0.4.11 中弃用后,已删除
jax.lax.prod
。请改用内置的math.prod
。已弃用与为自定义 JAX 原语定义 HLO 降级规则相关的
jax.interpreters.xla
中的一些导出内容。请改用jax.interpreters.mlir
中的 StableHLO 降级实用程序来定义自定义原语。以下先前已弃用的函数在经过三个月的弃用期后已被删除
jax.abstract_arrays.ShapedArray
:请使用jax.core.ShapedArray
。jax.abstract_arrays.raise_to_shaped
:请使用jax.core.raise_to_shaped
。jax.numpy.alltrue
:请使用jax.numpy.all
。jax.numpy.sometrue
:请使用jax.numpy.any
。jax.numpy.product
:请使用jax.numpy.prod
。jax.numpy.cumproduct
:请使用jax.numpy.cumprod
。
弃用/删除
内部子模块
jax.prng
现已弃用。其内容可在jax.extend.random
中找到。内部子模块路径
jax.linear_util
已弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块的一部分)jax.random.PRNGKeyArray
和jax.random.KeyArray
已弃用。请使用jax.Array
进行类型注释,并使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
进行类型化 prng 键的运行时检测。方法
PRNGKeyArray.unsafe_raw_array
已弃用。请改用jax.random.key_data()
。jax.experimental.pjit.with_sharding_constraint
已弃用。请改用jax.lax.with_sharding_constraint
。已删除内部实用程序
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
。不透明数据类型已重命名为扩展数据类型;请改用jnp.issubdtype(dtype, jax.dtypes.extended)
(从 jax v0.4.14 开始可用)。已删除实用程序
jax.interpreters.xla.register_collective_primitive
。此实用程序在最近的 JAX 版本中没有任何用处,可以安全地删除对它的调用。内部子模块路径
jax.linear_util
已弃用。请改用jax.extend.linear_util
(jax.extend:扩展模块的一部分)
jaxlib 0.4.16 (2023年9月18日)#
更改
通过实验性 jax sparse API 进行的稀疏 CSR 矩阵乘法不再在 NVIDIA GPU 上使用确定性算法。进行此更改是为了提高与 CUDA 12.2.1 的兼容性。
错误修复
修复了 Windows 上由于与乱序节和 IMAGE_REL_AMD64_ADDR32NB 重定位相关的致命 LLVM 错误导致的崩溃 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14 (2023年7月27日)#
更改
jax.jit
接受donate_argnames
作为参数。其语义类似于static_argnames
。如果既没有提供 donate_argnums 也没有提供 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 使用inspect.signature(fun)
查找与 donate_argnames 对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnums 和 donate_argnames,则不会使用 inspect.signature,并且只会捐赠在 donate_argnums 或 donate_argnames 中列出的实际参数。jax.random.gamma()
已重构为一种更高效的算法,具有更强大的端点行为 (#16779)。这意味着对于给定的key
返回的值序列将在 JAX v0.4.13 和 v0.4.14 之间发生变化,对于gamma
和相关的采样器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)。
删除
由于自弃用以来已超过 3 个月,
in_axis_resources
和out_axis_resources
已从 pjit 中删除。请使用in_shardings
和out_shardings
作为替换。这是一个安全且简单的名称替换。它不会更改任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将PartitionSpecs
传递给 in_shardings 和 out_shardings。
弃用
根据 https://jax.ac.cn/en/latest/deprecation.html 已放弃对 Python 3.8 的支持
根据 https://jax.ac.cn/en/latest/deprecation.html,JAX 现在需要 NumPy 1.22 或更高版本
通过位置传递可选参数到
jax.numpy.ndarray.at()
已不再支持,在 JAX 版本 0.4.7 中已弃用。例如,不要使用x.at[i].get(True)
,请使用x.at[i].get(indices_are_sorted=True)
以下
jax.Array
方法已删除,在 JAX v0.4.5 中已弃用jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
以下 API 在之前弃用后已被删除
jax.ad
:请使用jax.interpreters.ad
。jax.curry
:请使用curry = lambda f: partial(partial, f)
。jax.partial_eval
:请使用jax.interpreters.partial_eval
。jax.pxla
:请使用jax.interpreters.pxla
。jax.xla
:请使用jax.interpreters.xla
。jax.ShapedArray
:请使用jax.core.ShapedArray
。jax.interpreters.pxla.device_put
:请使用jax.device_put()
。jax.interpreters.pxla.make_sharded_device_array
:请使用jax.make_array_from_single_device_arrays()
。jax.interpreters.pxla.ShardedDeviceArray
:请使用jax.Array
。jax.numpy.DeviceArray
:请使用jax.Array
。jax.stages.Compiled.compiler_ir
:请使用jax.stages.Compiled.as_text()
。
重大更改
JAX 现在需要 ml_dtypes 版本 0.2.0 或更高版本。
为了修复一个极端情况,使用五个参数调用
jax.lax.cond()
将始终解析为“公共操作数”cond
行为(如文档中所述),如果第二个和第三个参数是可调用的,即使其他操作数也是可调用的。请参阅 #16413。已删除弃用的配置选项
jax_array
和jax_jit_pjit_api_merge
,它们什么也没做。这些选项在许多版本中默认都为 true。
新功能
JAX 现在支持一个配置标志 –jax_serialization_version 和一个 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本 (#16746)。
在存在形状多态性的情况下,jax2tf 现在会生成检查某些形状约束的代码,如果序列化版本至少为 7。请参阅 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
jaxlib 0.4.14 (2023年7月27日)#
弃用
根据 https://jax.ac.cn/en/latest/deprecation.html 已放弃对 Python 3.8 的支持
jax 0.4.13 (2023年6月22日)#
更改
jax.jit
现在允许将None
传递给in_shardings
和out_shardings
。语义如下对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。
jax.experimental.pjit.pjit
也允许将None
传递给in_shardings
和out_shardings
。语义如下如果 未 提供网格上下文管理器,则 JAX 可以自由选择任何它想要的分片。
对于 in_shardings,JAX 将将其标记为复制,但此行为将来可能会更改。
对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。
如果提供了网格上下文管理器,则 None 将意味着该值将在网格的所有设备上复制。
Executable.cost_analysis() 在 Cloud TPU 上有效
如果正在使用非允许列表的
jaxlib
插件,则添加了一个警告。添加了
jax.tree_util.tree_leaves_with_path
。None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效输入。如果您想复制输入,请使用jax.sharding.PartitionSpec()
。
错误修复
修复了 CUDA 12 版本中不正确的轮子名称 (#16362);正确的轮子名为
cudnn89
而不是cudnn88
。
弃用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
参数已弃用,取而代之的是新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13 (2023年6月22日)#
更改
将 Windows CPU 仅轮子添加到
jaxlib
Pypi 版本中。
错误修复
__cuda_array_interface__
在之前的 jaxlib 版本中已损坏,现在已修复 (#16440)。并发 CUDA 内核跟踪现在在 NVIDIA GPU 上默认启用。
jax 0.4.12 (2023年6月8日)#
更改
弃用
jax.abstract_arrays
及其内容现已弃用。请参阅 :mod:jax.core
中的相关功能。jax.numpy.alltrue
:请使用jax.numpy.all
。这遵循 NumPy 版本 1.25.0 中numpy.alltrue
的弃用。jax.numpy.sometrue
:请使用jax.numpy.any
。此更改遵循 NumPy 1.25.0 版本中numpy.sometrue
的弃用。jax.numpy.product
:请使用jax.numpy.prod
。此更改遵循 NumPy 1.25.0 版本中numpy.product
的弃用。jax.numpy.cumproduct
:请使用jax.numpy.cumprod
。此更改遵循 NumPy 1.25.0 版本中numpy.cumproduct
的弃用。jax.sharding.OpShardingSharding
由于已弃用 3 个月,现已移除。
jaxlib 0.4.12 (2023 年 6 月 8 日)#
更改
包含适用于 Hopper(SM 版本 9.0+)GPU 的 PTX/SASS。先前版本的 jaxlib 应该可以在 Hopper 上运行,但第一次执行 JAX 操作时会存在较长的 JIT 编译延迟。
错误修复
修复了 Python 3.11 下 JAX 生成的 Python 追溯信息中源代码行信息不正确的问题。
修复了在 JAX 生成的 Python 追溯中打印帧的局部变量时发生的崩溃问题 (#16027)。
jax 0.4.11 (2023 年 5 月 31 日)#
弃用
根据 API 兼容性 策略,以下 API 在 3 个月的弃用期后已被移除
jax.experimental.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.experimental.maps.Mesh
:请使用jax.sharding.Mesh
jax.experimental.pjit.NamedSharding
:请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.experimental.pjit.FROM_GDA
。请改为将分片后的jax.Array
对象作为输入,并移除pjit
的可选参数in_shardings
。jax.interpreters.pxla.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:请使用jax.sharding.Mesh
jax.interpreters.xla.Buffer
:请使用jax.Array
。jax.interpreters.xla.Device
:请使用jax.Device
。jax.interpreters.xla.DeviceArray
:请使用jax.Array
。jax.interpreters.xla.device_put
:请使用jax.device_put
。jax.interpreters.xla.xla_call_p
:请使用jax.experimental.pjit.pjit_p
。with_sharding_constraint
的参数axis_resources
已移除。请使用shardings
代替。
jaxlib 0.4.11 (2023 年 5 月 31 日)#
更改
为
Device
添加了memory_stats()
方法。如果支持,则返回一个包含字符串统计名称和整数值的字典,例如"bytes_in_use"
,或者如果平台不支持内存统计则返回 None。返回的具体统计信息可能因平台而异。目前仅在 Cloud TPU 上实现。重新添加了对 CPU 设备上 Python 缓冲协议(
memoryview
)的支持。
jax 0.4.10 (2023 年 5 月 11 日)#
jaxlib 0.4.10 (2023 年 5 月 11 日)#
更改
修复了阻止先前版本在 Mac M1 上运行的
'apple-m1' is not a recognized processor for this target (ignoring processor)
问题。
jax 0.4.9 (2023 年 5 月 9 日)#
更改
标志 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap 已移除。它们现在始终处于开启状态。
提高了 TPU 上奇异值分解 (SVD) 的精度(需要 jaxlib 0.4.9)。
弃用
jax.experimental.gda_serialization
已弃用,并已重命名为jax.experimental.array_serialization
。请更改您的导入以使用jax.experimental.array_serialization
。pjit 的参数
in_axis_resources
和out_axis_resources
已弃用。请分别使用in_shardings
和out_shardings
。函数
jax.numpy.msort
已移除。它自 JAX v0.4.1 起已弃用。请改用jnp.sort(a, axis=0)
。in_parts
和out_parts
参数已从jax.xla_computation
中移除,因为它们仅与 sharded_jit 一起使用,而 sharded_jit 已被弃用很久了。instantiate_const_outputs
参数已从jax.xla_computation
中移除,因为它很久以来一直未使用。
jaxlib 0.4.9 (2023 年 5 月 9 日)#
jax 0.4.8 (2023 年 3 月 29 日)#
重大更改
Cloud TPU 运行时的主要组件已升级。这使得 Cloud TPU 上可以启用以下新功能
jax.debug.print()
、jax.debug.callback()
和jax.debug.breakpoint()
现在可以在 Cloud TPU 上使用自动 TPU 内存碎片整理
jax.experimental.host_callback()
在使用新运行时组件的 Cloud TPU 上不再受支持。如果您发现新的jax.debug
API 不足以满足您的用例,请在 JAX 问题跟踪器 上提交问题。通过设置环境变量
JAX_USE_PJRT_C_API_ON_TPU=false
,旧的运行时组件至少将在未来三个月内可用。如果您发现出于任何原因需要禁用新运行时,请在 JAX 问题跟踪器 上告知我们。
更改
jaxlib 的最低版本已从 0.4.6 提升到 0.4.7。
弃用
已放弃对 CUDA 11.4 的支持。JAX GPU 轮子仅支持 CUDA 11.8 和 CUDA 12。如果从源代码构建 jaxlib,则较旧的 CUDA 版本可能有效。
pmap 的参数
global_arg_shapes
仅与 sharded_jit 一起使用,现已从 pmap 中移除。请迁移到 pjit 并从 pmap 中移除 global_arg_shapes。
jax 0.4.7 (2023 年 3 月 27 日)#
更改
根据 https://jax.ac.cn/en/latest/jax_array_migration.html#jax-array-migration,
jax.config.jax_array
无法再禁用。jax.config.jax_jit_pjit_api_merge
无法再禁用。jax.experimental.jax2tf.convert()
现在支持native_serialization
参数,以使用 JAX 的原生降低到 StableHLO,从而获得整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降低到 TensorFlow 操作。这简化了内部结构,并提高了您序列化内容与 JAX 原生语义匹配的信心。请参见 文档。作为此更改的一部分,配置标志--jax2tf_default_experimental_native_lowering
已重命名为--jax2tf_native_serialization
。JAX 现在依赖于
ml_dtypes
,其中包含 NumPy 类型(如 bfloat16)的定义。这些定义以前是 JAX 的内部定义,但现在已拆分为一个单独的包,以便与其他项目共享。JAX 现在需要 NumPy 1.21 或更高版本以及 SciPy 1.7 或更高版本。
弃用
类型
jax.numpy.DeviceArray
已弃用。请改用jax.Array
,它是一个别名。类型
jax.interpreters.pxla.ShardedDeviceArray
已弃用。请改用jax.Array
。通过位置向
jax.numpy.ndarray.at()
传递其他参数已弃用。例如,请改用x.at[i].get(indices_are_sorted=True)
,而不是x.at[i].get(True)
。jax.interpreters.xla.device_put
已弃用。请使用jax.device_put
。jax.interpreters.pxla.device_put
已弃用。请使用jax.device_put
。jax.experimental.pjit.FROM_GDA
已弃用。请将分片 Jax 数组作为输入传递,并删除in_shardings
参数,因为它现在是可选的。
jaxlib 0.4.7 (2023年3月27日)#
更改
jaxlib 现在依赖于
ml_dtypes
,其中包含 bfloat16 等 NumPy 类型定义。这些定义以前是 JAX 内部使用的,但现在已拆分为单独的包,以便与其他项目共享。
jax 0.4.6 (2023年3月9日)#
更改
jax.tree_util
现在包含一组 API,允许用户为其自定义 pytree 节点定义键。这包括tree_flatten_with_path
,它会展平树并返回每个叶子及其键路径。tree_map_with_path
,它可以映射一个函数,该函数将键路径作为参数。register_pytree_with_keys
,用于注册自定义 pytree 节点中键路径和叶子应如何呈现。keystr
,用于美化打印键路径。
jax2tf.call_tf()
有一个新的参数output_shape_dtype
(默认为None
),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()
能够在存在形状多态性的情况下工作。(#14734)。
弃用
jax.tree_util
中旧的键路径 API 已弃用,并将于 2023 年 3 月 10 日起 3 个月后删除。register_keypaths
:请改用jax.tree_util.register_pytree_with_keys()
。AttributeKeyPathEntry
:请改用GetAttrKey
。GetitemKeyPathEntry
:请改用SequenceKey
或DictKey
。
jaxlib 0.4.6 (2023年3月9日)#
jax 0.4.5 (2023年3月2日)#
弃用
jax.sharding.OpShardingSharding
已重命名为jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
将于 2023 年 2 月 17 日起 3 个月后删除。以下
jax.Array
方法已弃用,并将于 2023 年 2 月 23 日起 3 个月后删除。jax.Array.broadcast
:请改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:请改用jax.lax.broadcast_in_dim()
。jax.Array.split
:请改用jax.numpy.split()
。
jax 0.4.4 (2023年2月16日)#
更改
jit
和pjit
的实现已合并。合并 jit 和 pjit 会更改 JAX 的内部结构,但不会影响 JAX 的公共 API。之前,jit
是一种最终风格的原语。最终风格意味着 jaxpr 的创建尽可能延迟,并且转换会堆叠在一起。通过jit
-pjit
实现合并,jit
成为了一种初始风格的原语,这意味着我们会尽早跟踪到 jaxpr。有关更多信息,请参阅 autodidax 中的此部分。转向初始风格应该可以简化 JAX 的内部结构,并使动态形状等功能的开发更容易。您只能通过环境变量禁用它,即os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。合并必须通过环境变量禁用,因为它会影响 JAX 在导入时的行为,因此需要在导入 jax 之前禁用它。with_sharding_constraint
的axis_resources
参数已弃用。请改用shardings
。如果您将axis_resources
作为参数使用,则无需进行任何更改。如果您将其作为关键字参数使用,则请改用shardings
。axis_resources
将于 2023 年 2 月 13 日起 3 个月后删除。添加了
jax.typing
模块,其中包含用于 JAX 函数类型注释的工具。以下名称已弃用
jax.xla.Device
和jax.interpreters.xla.Device
:请改用jax.Device
。jax.experimental.maps.Mesh
。请改用jax.sharding.Mesh
。jax.experimental.pjit.NamedSharding
:请使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:请使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:请改用jax.sharding.Mesh
。jax.interpreters.pxla.PartitionSpec
:请使用jax.sharding.PartitionSpec
。
重大更改
还原函数(如 :func:
jax.numpy.sum
)的initial
参数现在要求为标量,与相应的 NumPy API 保持一致。之前针对非标量initial
值广播输出的行为是一个意外的实现细节(#14446)。
jaxlib 0.4.4 (2023年2月16日)#
重大更改
默认
jaxlib
版本已删除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以通过源代码构建jaxlib
并启用 Kepler 支持(通过--cuda_compute_capabilities=sm_35
选项传递给build.py
),但请注意,CUDA 12 已完全放弃对 Kepler GPU 的支持。
jax 0.4.3 (2023年2月8日)#
重大更改
删除了
jax.scipy.linalg.polar_unitary()
,它是一个已弃用的 JAX 对 scipy API 的扩展。请改用jax.scipy.linalg.polar()
。
更改
jaxlib 0.4.3 (2023年2月8日)#
jax.Array
现在具有非阻塞is_ready()
方法,如果数组已准备好,则返回True
(另请参阅jax.block_until_ready()
)。
jax 0.4.2 (2023年1月24日)#
重大更改
更改
jaxlib 0.4.2 (2023年1月24日)#
更改
设置 JAX_USE_PJRT_C_API_ON_TPU=1 以启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。
jax 0.4.1 (2022年12月13日)#
更改
已停止支持 Python 3.7,这符合 JAX 的 Python 和 NumPy 版本支持策略。
我们引入了
jax.Array
,这是一种统一的数组类型,它包含了 JAX 中的DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
类型。jax.Array
类型有助于使并行性成为 JAX 的核心功能,简化并统一 JAX 的内部结构,并允许我们统一jit
和pjit
。jax.Array
已在 JAX 0.4 中默认启用,并且对pjit
API 产生了一些重大更改。 jax.Array 迁移指南 可以帮助您将代码库迁移到jax.Array
。您还可以查看 分布式数组和自动并行化 教程以了解新概念。PartitionSpec
和Mesh
现在已退出实验阶段。新的 API 端点为jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已弃用,并将在 3 个月内移除。with_sharding_constraint
的新公共端点为jax.lax.with_sharding_constraint
。如果将 ABSL 标志与
jax.config
结合使用,则在 JAX 配置选项最初从 ABSL 标志填充后,将不再读取或写入 ABSL 标志值。此更改提高了读取jax.config
选项的性能,这些选项在 JAX 中被广泛使用。jax2tf.call_tf 函数现在使用与嵌入式 JAX 计算使用的相同平台的第一个 TF 设备进行 TF 降级。之前,它使用 JAX 默认后端的第 0 个设备。
现在,许多
jax.numpy
函数的参数都被标记为仅位置参数,与 NumPy 一致。jnp.msort
现已弃用,遵循 numpy 1.24 中np.msort
的弃用。它将在未来的版本中移除,符合 API 兼容性 策略。它可以用jnp.sort(a, axis=0)
替换。
jaxlib 0.4.1 (2022 年 12 月 13 日)#
更改
已停止支持 Python 3.7,这符合 JAX 的 Python 和 NumPy 版本支持策略。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行为已更改为分配总 GPU 内存的 XX%,而不是之前使用当前可用的 GPU 内存来计算预分配的行为。有关更多详细信息,请参阅 GPU 内存分配。已移除弃用的方法
.block_host_until_ready()
。请改用.block_until_ready()
。
jax 0.4.0 (2022 年 12 月 12 日)#
该版本已被撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)#
该版本已被撤回。
jax 0.3.25 (2022 年 11 月 15 日)#
更改
jax.numpy.linalg.pinv()
现在支持hermitian
选项。jax.scipy.linalg.hessenberg()
现在仅在 CPU 上受支持。需要 jaxlib > 0.3.24。添加了新函数
jax.lax.linalg.hessenberg()
、jax.lax.linalg.tridiagonal()
和jax.lax.linalg.householder_product()
。Householder 约简目前仅限于 CPU,三对角约简仅在 CPU 和 GPU 上受支持。现在,对于非方阵,
svd
和jax.numpy.linalg.pinv
的梯度计算得更加经济。
重大更改
删除了
jax_experimental_name_stack
配置选项。将字符串
axis_names
参数转换为jax.experimental.maps.Mesh
构造函数中的单例元组,而不是将字符串解包成一系列字符轴名称。
jaxlib 0.3.25 (2022 年 11 月 15 日)#
更改
添加了对 CPU 和 GPU 上的三对角约简的支持。
添加了对 CPU 上的上 Hessenberg 约简的支持。
错误
修复了一个错误,该错误导致 JAX 捕获的回溯中的帧在 Python 3.10+ 下被错误地映射到源代码行。
jax 0.3.24 (2022 年 11 月 4 日)#
更改
JAX 的导入速度应该更快了。我们现在延迟导入 scipy,这占了 JAX 导入时间的大部分。
设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过 1 秒的计算将被缓存。如果未指定顺序,则 TPU 上
pmap
使用的默认设备顺序现在与单进程作业的jax.devices()
匹配。以前,这两种排序方式不同,这可能导致不必要的复制或内存不足错误。要求排序方式一致可以简化问题。
重大更改
jax.numpy.gradient()
现在与jax.numpy
中的大多数其他函数一样,禁止传递列表或元组来代替数组 (#12958)jax.numpy.linalg
和jax.numpy.fft
中的函数现在统一要求输入为类数组:即不能使用列表和元组代替数组。作为 #7737 的一部分。
弃用
jax.sharding.MeshPspecSharding
已重命名为jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名称将在 3 个月内移除。
jaxlib 0.3.24 (2022 年 11 月 4 日)#
更改
缓冲区捐赠现在可以在 CPU 上工作。这可能会破坏在 CPU 上标记要捐赠的缓冲区的代码,但依赖于未实现捐赠。
jax 0.3.23 (2022 年 10 月 12 日)#
更改
更新 Colab TPU 驱动程序版本以适应新的 jaxlib 版本。
jax 0.3.22 (2022 年 10 月 11 日)#
更改
在 TPU 初始化中将
JAX_PLATFORMS=tpu,cpu
作为默认设置,因此如果无法初始化 TPU,JAX 将引发错误,而不是回退到 CPU。设置JAX_PLATFORMS=''
以覆盖此行为并自动选择可用的后端(原始默认值),或设置JAX_PLATFORMS=cpu
以始终使用 CPU,而不管 TPU 是否可用。
弃用
JAX v0.3.8 中弃用的一些测试实用程序现已从
jax.test_util
中移除。
jaxlib 0.3.22 (2022 年 10 月 11 日)#
jax 0.3.21 (2022 年 9 月 30 日)#
jax 0.3.20 (2022 年 9 月 28 日)#
jaxlib 0.3.20 (2022 年 9 月 28 日)#
jax 0.3.19 (2022 年 9 月 27 日)#
修复了所需的 jaxlib 版本。
jax 0.3.18 (2022 年 9 月 26 日)#
更改
提前降低和编译功能(在 #7733 中跟踪)已稳定并公开发布。请参阅 概述 和
jax.stages
的 API 文档。引入了
jax.Array
,用于 JAX 中数组类型的isinstance
检查和类型注释。请注意,这包括一些关于jax.numpy.ndarray
对 jax 内部对象的isinstance
工作方式的细微变化,因为jax.numpy.ndarray
现在是jax.Array
的简单别名。
重大更改
jax._src
现在不再导入到公共jax
命名空间中。这可能会破坏使用 JAX 内部组件的用户。jax.soft_pmap
已被删除。请使用pjit
或xmap
代替。jax.soft_pmap
未记录在案。如果已记录,则会提供弃用期。
jax 0.3.17 (2022年8月31日)#
错误
修复了
lax.pow
指数为零时的梯度中的极端情况问题(#12041)
重大更改
jax.checkpoint()
,也称为jax.remat()
,不再支持concrete
选项,遵循先前版本的弃用;请参阅JEP 11830。
更改
添加了
jax.pure_callback()
,它允许从编译函数(例如,用jax.jit
或jax.pmap
装饰的函数)回调到纯Python函数。
弃用
已弃用的
DeviceArray.tile()
方法已被删除。请使用jax.numpy.tile()
(#11944)。DeviceArray.to_py()
已被弃用。请改用np.asarray(x)
。
jax 0.3.16#
重大更改
根据弃用策略,已放弃对NumPy 1.19的支持。请升级到NumPy 1.20或更高版本。
更改
添加了
jax.debug
,其中包含用于运行时值调试的实用程序,例如jax.debug.print()
和jax.debug.breakpoint()
。为运行时值调试添加了新的文档。
弃用
jax.mask()
jax.shapecheck()
API已被删除。请参阅#11557。jax.experimental.loops
已被删除。请参阅#10278了解替代API。jax.tree_util.tree_multimap()
已被删除。自JAX版本0.3.5起,它已被弃用,并且jax.tree_util.tree_map()
是直接替换。已删除
jax.experimental.stax
;它长期以来一直是jax.example_libraries.stax
的弃用别名。已删除
jax.experimental.optimizers
;它长期以来一直是jax.example_libraries.optimizers
的弃用别名。jax.checkpoint()
,也称为jax.remat()
,具有一个新的默认启用的实现,这意味着旧的实现已被弃用;请参阅JEP 11830。
jax 0.3.15 (2022年7月22日)#
更改
JaxTestCase
和JaxTestLoader
已从jax.test_util
中删除。这些类自v0.3.1起已被弃用(#11248)。添加了
jax.scipy.gaussian_kde
(#11237)。JAX数组和内置集合(
dict
、list
、set
、tuple
)之间的二元运算现在在所有情况下都会引发TypeError
。以前某些情况(特别是相等和不等)会返回与NumPy中类似操作不一致的布尔标量(#11234)。一些作为顶级JAX包导入访问的
jax.tree_util
例程现在已被弃用,并将在未来的JAX版本中根据API兼容性策略删除。jax.treedef_is_leaf()
已被弃用,建议使用jax.tree_util.treedef_is_leaf()
jax.tree_flatten()
已被弃用,建议使用jax.tree_util.tree_flatten()
jax.tree_leaves()
已被弃用,建议使用jax.tree_util.tree_leaves()
jax.tree_structure()
已被弃用,建议使用jax.tree_util.tree_structure()
jax.tree_transpose()
已被弃用,建议使用jax.tree_util.tree_transpose()
jax.tree_unflatten()
已被弃用,建议使用jax.tree_util.tree_unflatten()
jax.scipy.linalg.solve()
的sym_pos
参数已被弃用,建议使用assume_a='pos'
,遵循scipy.linalg.solve()
中的类似弃用。
jaxlib 0.3.15 (2022年7月22日)#
jax 0.3.14 (2022年6月27日)#
重大更改
jax.experimental.compilation_cache.initialize_cache()
不再支持max_cache_size_ bytes
,并且不会将其作为输入。当平台初始化失败时,
JAX_PLATFORMS
现在会引发异常。
更改
修复了与NumPy 1.23的兼容性问题。
jax.numpy.linalg.slogdet()
现在接受一个可选的method
参数,该参数允许在基于LU分解的实现和基于QR分解的实现之间进行选择。jax.numpy.linalg.qr()
现在支持mode="raw"
。pickle
、copy.copy
和copy.deepcopy
在用于jax数组时现在具有更完整的支持(#10659)。特别是当用于
DeviceArray
时,pickle
和deepcopy
以前会返回np.ndarray
对象;现在返回DeviceArray
对象。对于deepcopy
,复制的数组与原始数组位于同一设备上。对于pickle
,反序列化的数组将位于默认设备上。在函数转换(即跟踪代码)中,
deepcopy
和copy
以前是无操作的。现在它们使用与DeviceArray.copy()
相同的机制。在跟踪数组上调用
pickle
现在会导致显式的ConcretizationTypeError
。
奇异值分解(SVD)和对称/厄米特特征分解的实现在TPU上应该会显著加快,尤其是在1000x1000或更大的矩阵上。两者现在都使用谱分治算法进行特征分解(QDWH-eig)。
jax.numpy.ldexp()
不再静默地将所有输入提升为float64,而是对于大小为int32或更小的整数输入将其提升为float32(#10921)。向
jax.profiler.start_trace()
和jax.profiler.start_trace()
添加了create_perfetto_link
选项。使用时,分析器将生成指向Perfetto UI的链接以查看跟踪。更改了
jax.profiler.start_server(...)()
的语义,使其全局存储保持活动状态,而不是要求用户保留对它的引用。添加了一个
python -m jax.collect_profile
脚本,以手动捕获程序跟踪作为TensorBoard UI的替代方案。添加了一个
jax.named_scope
上下文管理器,它将分析器元数据添加到Python程序中(类似于jax.named_call
)。在散射更新操作(即:attr:
jax.numpy.ndarray.at
)中,不安全的隐式dtype转换已弃用,现在会产生FutureWarning
。在将来的版本中,这将成为错误。不安全的隐式转换的一个示例是jnp.zeros(4, dtype=int).at[0].set(1.5)
,其中1.5
以前被静默截断为1
。jax.experimental.compilation_cache.initialize_cache()
现在支持gcs存储桶路径作为输入。jax.numpy.roots()
在系数具有前导零时,当strip_zeros=False
时,现在表现得更好(#11215)。
jaxlib 0.3.14 (2022年6月27日)#
-
x86-64 Mac轮子现在需要Mac OS 10.14(Mojave)或更高版本。Mac OS 10.14于2018年发布,因此这不是一个非常繁重的要求。
捆绑的NCCL版本已更新至2.12.12,修复了一些死锁问题。
Python flatbuffers包不再是jaxlib的依赖项。
jax 0.3.13 (2022年5月16日)#
jax 0.3.12 (2022年5月15日)#
jax 0.3.11 (2022年5月15日)#
更改
jax.lax.eigh()
现在接受一个可选的sort_eigenvalues
参数,允许用户选择退出TPU上的特征值排序。
弃用
jax.lax.linalg
中函数的非数组参数现在被标记为仅限关键字。作为向后兼容性步骤,按位置传递仅限关键字的参数会产生警告,但在将来的JAX版本中,按位置传递仅限关键字的参数将失败。但是,大多数用户应该更喜欢使用jax.numpy.linalg
。jax.scipy.linalg.polar_unitary()
(它是scipy API的JAX扩展)已弃用。请改用jax.scipy.linalg.polar()
。
jax 0.3.10 (2022年5月3日)#
jaxlib 0.3.10 (2022年5月3日)#
jax 0.3.9 (2022年5月2日)#
更改
添加了对GlobalDeviceArray完全异步检查点的支持。
jax 0.3.8 (2022年4月29日)#
更改
jax.numpy.linalg.svd()
在TPU上使用qdwh-svd求解器。jax.numpy.linalg.cond()
在TPU上现在接受复数输入。jax.numpy.linalg.pinv()
在TPU上现在接受复数输入。jax.numpy.linalg.matrix_rank()
在TPU上现在接受复数输入。jax.experimental.maps.mesh
已被删除。请使用jax.experimental.maps.Mesh
。有关更多信息,请参阅https://jax.ac.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh。jax.scipy.linalg.qr()
现在在mode='r'
时返回长度为1的元组而不是原始数组,以匹配scipy.linalg.qr
的行为(#10452)。jax.numpy.take_along_axis()
现在接受一个可选的mode
参数,该参数指定越界索引的行为。默认情况下,对于越界索引将返回无效值(例如,NaN)。在以前版本的JAX中,无效索引被钳位到范围内。可以通过传递mode="clip"
来恢复以前的行为。jax.numpy.take()
现在默认为mode="fill"
,对于越界索引返回无效值(例如,NaN)。散射操作(例如
x.at[...].set(...)
)现在具有"drop"
语义。这本身对散射操作没有影响,但意味着当进行微分时,散射的梯度将为越界索引产生零余切。以前,越界索引被钳位到梯度的范围内,这在数学上是不正确的。jax.numpy.take_along_axis()
如果其索引不是整数类型,现在会引发TypeError
,这与numpy.take_along_axis()
的行为匹配。以前,非整数索引被静默转换为整数。jax.numpy.ravel_multi_index()
如果其dims
参数不是整数类型,现在会引发TypeError
,这与numpy.ravel_multi_index()
的行为匹配。以前,非整数dims
被静默转换为整数。jax.numpy.split()
如果其axis
参数不是整数类型,现在会引发TypeError
,这与numpy.split()
的行为匹配。以前,非整数axis
被静默转换为整数。jax.numpy.indices()
如果其维度不是整数类型,现在会引发TypeError
,这与numpy.indices()
的行为匹配。以前,非整数维度被静默转换为整数。jax.numpy.diag()
如果其k
参数不是整数类型,现在会引发TypeError
,这与numpy.diag()
的行为匹配。以前,非整数k
被静默转换为整数。
弃用
jax.test_util
中的许多函数和对象现已弃用,并在导入时引发警告。这包括cases_from_list
、check_close
、check_eq
、device_under_test
、format_shape_dtype_string
、rand_uniform
、skip_on_devices
、with_config
、xla_bridge
和_default_tolerance
(#10389)。这些以及之前已弃用的JaxTestCase
、JaxTestLoader
和BufferDonationTestCase
将在未来 JAX 版本中移除。大多数这些实用程序都可以替换为对标准 Python 和 NumPy 测试实用程序的调用,例如unittest
、absl.testing
、numpy.testing
等。JAX 特定的功能(如设备检查)可以通过使用公共 API(如jax.devices()
)来替换。许多已弃用的实用程序仍然存在于jax._src.test_util
中,但这些不是公共 API,因此在未来的版本中可能会更改或删除,恕不另行通知。
jax 0.3.7 (2022 年 4 月 15 日)#
更改
如果传递给
jax.numpy.take_along_axis()
的索引被广播,则修复了一个性能问题(#10281)。jax.scipy.special.expit()
和jax.scipy.special.logit()
现在要求其参数为标量或 JAX 数组。它们现在还将整数参数提升为浮点数。DeviceArray.tile()
方法已弃用,因为 NumPy 数组没有tile()
方法。作为替代,请使用jax.numpy.tile()
(#10266)。
jaxlib 0.3.7 (2022 年 4 月 15 日)#
更改
Linux 轮子现在按照
manylinux2014
标准构建,而不是manylinux2010
。
jax 0.3.6 (2022 年 4 月 12 日)#
jax 0.3.5 (2022 年 4 月 7 日)#
更改
添加了
jax.random.loggamma()
并改进了jax.random.beta()
和jax.random.dirichlet()
在小参数值下的行为(#9906)。私有的
lax_numpy
子模块不再在jax.numpy
命名空间中公开(#10029)。添加了数组创建例程
jax.numpy.frombuffer()
、jax.numpy.fromfunction()
和jax.numpy.fromstring()
(#10049)。DeviceArray.copy()
现在返回DeviceArray
而不是np.ndarray
(#10069)。jax.experimental.sharded_jit
已弃用,并将很快被移除。
弃用
jax.nn.normalize()
即将弃用。请改用jax.nn.standardize()
(#9899)。jax.tree_util.tree_multimap()
已弃用。请改用jax.tree_util.tree_map()
(#5746)。jax.experimental.sharded_jit
已弃用。请改用pjit
。
jaxlib 0.3.5 (2022 年 4 月 7 日)#
jax 0.3.4 (2022 年 3 月 18 日)#
jax 0.3.3 (2022 年 3 月 17 日)#
jax 0.3.2 (2022 年 3 月 16 日)#
更改
函数
jax.ops.index_update
、jax.ops.index_add
(在 0.2.22 中已弃用)已被移除。请改用 JAX 数组上的.at
属性,例如x.at[idx].set(y)
。将
jax.experimental.ann.approx_*_k
移动到jax.lax
中。这些函数是jax.lax.top_k
的优化替代方案。jax.numpy.broadcast_arrays()
和jax.numpy.broadcast_to()
现在需要标量或类数组输入,如果传递列表则会失败(#7737 的一部分)。标准 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。
pjit
现在可以在 CPU 上运行(除了之前的 TPU 和 GPU 支持)。
jaxlib 0.3.2 (2022 年 3 月 16 日)#
更改
XlaComputation.as_hlo_text()
现在支持通过传递布尔标志print_large_constants=True
来打印大型常量。
弃用
JAX 数组上的
.block_host_until_ready()
方法已弃用。请改用.block_until_ready()
。
jax 0.3.1 (2022 年 2 月 18 日)#
更改
jax.test_util.JaxTestCase
和jax.test_util.JaxTestLoader
现已弃用。建议的替代方法是直接使用parametrized.TestCase
。对于依赖于自定义断言(如JaxTestCase.assertAllClose()
)的测试,建议的替代方法是使用标准 NumPy 测试实用程序(如numpy.testing.assert_allclose()
),它们可以直接与 JAX 数组一起使用(#9620)。jax.test_util.JaxTestCase
现在默认设置jax_numpy_rank_promotion='raise'
(#9562)。要恢复以前的行为,请使用新的jax.test_util.with_config
装饰器。@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
添加了
jax.scipy.linalg.schur()
、jax.scipy.linalg.sqrtm()
、jax.scipy.signal.csd()
、jax.scipy.signal.stft()
和jax.scipy.signal.welch()
。
jax 0.3.0 (2022年2月10日)#
jaxlib 0.3.0 (2022年2月10日)#
更改
现在需要 Bazel 5.0.0 来构建 jaxlib。
jaxlib 版本已更新至 0.3.0。请参阅 设计文档 以获取解释。
jax 0.2.28 (2022年2月1日)#
-
jax.jit(f).lower(...).compiler_ir()
现在默认为 MHLO 方言,如果未传递dialect=
。jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
现在返回 MLIRir.Module
对象,而不是其字符串表示形式。
jaxlib 0.1.76 (2022年1月27日)#
新功能
包含为 Nvidia 计算能力 8.0 GPU(例如 A100)预编译的 SASS。删除了为计算能力 6.1 预编译的 SASS,以免增加计算能力的数量:具有计算能力 6.1 的 GPU 可以使用 6.0 SASS。
使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。
重大更改
根据 弃用策略,已放弃对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。
错误修复
修复了通过不同路径构建的明显相同的 pytreedef 对象无法比较为相等的问题(#9066)。
JAX jit 缓存需要两个静态参数具有相同的类型才能命中缓存(#9311)。
jax 0.2.27 (2022年1月18日)#
重大更改
根据 弃用策略,已放弃对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。
主机回调原语已简化,以删除 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只提取基本值。可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS
环境变量或--jax_host_callback_ad_transforms
标志来获得旧行为(有限时间)。此外,添加了有关如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。排序现在与 NumPy 的行为匹配,无论位表示如何,对于
0.0
和NaN
都是如此。特别是,0.0
和-0.0
现在被视为等效,而以前-0.0
被视为小于0.0
。此外,所有NaN
表示现在都被视为等效,并排序到数组的末尾。以前,负NaN
值被排序到数组的前面,并且具有不同内部位表示的NaN
值不被视为等效,并且根据这些位模式进行排序(#9178)。jax.numpy.unique()
现在处理NaN
值的方式与 NumPy 1.21 及更高版本中的np.unique
相同:唯一化输出中最多会出现一个NaN
值(#9184)。
错误修复
host_callback 现在支持 ad_checkpoint.checkpoint(#8907)。
新功能
添加
jax.block_until_ready
({jax-issue}`#8941`)添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path
。如果设置,JAX 会将其为每个计算生成的 MHLO/HLO IR 导出到给定路径下的文件。将
jax.ensure_compile_time_eval
添加到公共 api(#7987)。jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions 来更改关联归约(例如 jnp.cumsum)的降低方式,使其在 CPU 和 GPU 上的行为类似于 JAX(使用关联扫描)。有关更多详细信息,请参阅 jax2tf 自述文件(#9189)。
jaxlib 0.1.75 (2021年12月8日)#
新功能
支持 python 3.10。
jax 0.2.26 (2021年12月8日)#
jaxlib 0.1.74 (2021年11月17日)#
启用了 GPU 之间的点对点复制。以前,GPU 复制通过主机进行反弹,这通常比较慢。
添加了实验性的 MLIR Python 绑定,供 JAX 使用。
jax 0.2.25 (2021年11月10日)#
jax 0.2.24 (2021年10月19日)#
jaxlib 0.1.73 (2021年10月18日)#
jaxlib GPU
cuda11
轮现在支持多个 cuDNN 版本。cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 轮,因为它支持其他功能。
cuDNN 8.0.5 或更高版本。
重大更改
GPU jaxlib 的安装命令如下
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (2021年10月12日)#
重大更改
jax.pmap
的静态参数现在必须是可散列的。不可散列的静态参数长期以来一直不允许在
jax.jit
上使用,但它们仍然允许在jax.pmap
上使用;jax.pmap
使用对象标识来比较不可散列的静态参数。此行为是一个陷阱,因为使用对象标识比较参数会导致每次对象标识更改时都重新编译。相反,我们现在禁止不可散列的参数:如果
jax.pmap
的用户希望通过对象标识来比较静态参数,他们可以在其对象上定义__hash__
和__eq__
方法来执行此操作,或者将其对象包装在一个具有这些操作的对象中,这些操作具有对象标识语义。另一种选择是使用functools.partial
将不可散列的静态参数封装到函数对象中。jax.util.partial
是一个意外导出的内容,现已删除。请改用 Python 标准库中的functools.partial
。
弃用
函数
jax.ops.index_update
、jax.ops.index_add
等已弃用,将在未来的 JAX 版本中删除。请改用 JAX 数组上的.at
属性,例如x.at[idx].set(y)
。目前,这些函数会产生DeprecationWarning
。
新功能
当使用 jaxlib 0.1.72 或更高版本时,一个优化的 C++ 代码路径改进了
pmap
的分派时间,现在是默认设置。可以使用--experimental_cpp_pmap
标志(或JAX_CPP_PMAP
环境变量)禁用此功能。jax.numpy.unique
现在支持可选的fill_value
参数(#8121)
jaxlib 0.1.72 (2021年10月12日)#
重大更改
已放弃对 CUDA 10.2 和 CUDA 10.1 的支持。Jaxlib 现在支持 CUDA 11.1+。
错误修复
修复了 https://github.com/google/jax/issues/7461,该问题由于 XLA 编译器内部的缓冲区别名不正确,导致所有平台上的输出错误。
jax 0.2.21 (2021年9月23日)#
重大更改
jax.api
已删除。作为jax.api.*
可用的函数是jax.*
中函数的别名;请改用jax.*
中的函数。jax.partial
和jax.lax.partial
是意外导出的内容,现已删除。请改用 Python 标准库中的functools.partial
。布尔标量索引现在会引发
TypeError
;以前这会静默地返回错误的结果(#7925)。更多
jax.numpy
函数现在需要类数组的输入,如果传入列表则会报错(#7747 #7802 #7907)。有关此更改背后的原因,请参阅#7737。在
jax.jit
等变换内部,jax.numpy.array
始终将其生成的数组分阶段到跟踪的计算中。以前,jax.numpy.array
有时会生成一个设备上的数组,即使在jax.jit
装饰器下也是如此。此更改可能会破坏使用JAX数组执行必须静态知道的形状或索引计算的代码;解决方法是使用经典的NumPy数组执行此类计算。jnp.ndarray
现在是JAX数组的真正基类。特别是,这意味着对于标准的numpy数组x
,isinstance(x, jnp.ndarray)
现在将返回False
(#7927)。
新功能
添加了
jax.numpy.insert()
实现(#7936)。
jax 0.2.20 (2021年9月2日)#
jaxlib 0.1.71 (2021年9月1日)#
重大更改
已放弃对CUDA 11.0和CUDA 10.1的支持。Jaxlib现在支持CUDA 10.2和CUDA 11.1及更高版本。
jax 0.2.19 (2021年8月12日)#
重大更改
根据弃用策略,已放弃对NumPy 1.17的支持。请升级到受支持的NumPy版本。
已在JAX数组上的一些运算符的实现周围添加了
jit
装饰器。这加快了常见运算符(如+
)的分派时间。此更改对于大多数用户来说应该是透明的。但是,有一个已知的行为更改,即大型整数常量在直接传递给JAX运算符时现在可能会产生错误(例如,
x + 2**40
)。解决方法是将常量转换为显式类型(例如,np.float64(2**40)
)。
新功能
改进了jax2tf中形状多态性的支持,用于需要在数组计算中使用维度大小的运算,例如
jnp.mean
。(#7317)
错误修复
上一个版本中的一些泄漏跟踪错误(#7613)
jaxlib 0.1.70 (2021年8月9日)#
jax 0.2.18 (2021年7月21日)#
重大更改
根据弃用策略,已放弃对Python 3.6的支持。请升级到受支持的Python版本。
jaxlib的最小版本现在为0.1.69。
jax.dlpack.from_dlpack()
的backend
参数已删除。
新功能
添加了极分解(
jax.scipy.linalg.polar()
)。
错误修复
加强了对lax.argmin和lax.argmax的检查,以确保它们不会与无效的
axis
值或空约简维度一起使用。(#7196)
jaxlib 0.1.69 (2021年7月9日)#
修复了导致TFRT CPU后端结果不正确的一些错误。
jax 0.2.17 (2021年7月9日)#
错误修复
对于jaxlib <= 0.1.68,默认使用旧的“stream_executor”CPU运行时,以解决#7229,该问题由于并发问题导致CPU上的输出错误。
新功能
新的SciPy函数
jax.scipy.special.sph_harm()
。反向模式自动微分函数(
jax.grad()
、jax.value_and_grad()
、jax.vjp()
和jax.linear_transpose()
)支持一个参数,该参数指示如果在正向传递中广播了哪些命名轴,则应在反向传递中对它们求和。这使得能够在映射内部以非按示例的方式使用这些API(最初仅限于jax.experimental.maps.xmap()
)(#6950)。
jax 0.2.16 (2021年6月23日)#
jax 0.2.15 (2021年6月23日)#
jaxlib 0.1.68 (2021年6月23日)#
错误修复
修复了TFRT CPU后端在将TPU缓冲区传输到CPU时出现NaN的错误。
jax 0.2.14 (2021年6月10日)#
新功能
jax2tf.convert()
现在支持pjit
和sharded_jit
。一个新的配置选项JAX_TRACEBACK_FILTERING控制JAX如何过滤回溯。
一个新的使用
__tracebackhide__
的回溯过滤模式现在在足够新的IPython版本中默认启用。jax2tf.convert()
支持形状多态性,即使未知维度用于算术运算,例如jnp.reshape(-1)
(#6827)。jax2tf.convert()
在TF操作中生成带有位置信息的自定义属性。jax2tf之后XLA生成的代码与JAX/XLA具有相同的位置信息。新的SciPy函数
jax.scipy.special.lpmn()
。
错误修复
jaxlib 0.1.67 (2021年5月17日)#
jaxlib 0.1.66 (2021年5月11日)#
新功能
CUDA 11.1轮子现在在所有CUDA 11版本11.1或更高版本上都受支持。
NVidia现在承诺从CUDA 11.1开始,CUDA次要版本之间兼容。这意味着JAX可以发布一个与CUDA 11.2和11.3兼容的CUDA 11.1轮子。
不再有针对CUDA 11.2(或更高版本)的单独jaxlib版本;对于这些版本,请使用CUDA 11.1轮子(cuda111)。
Jaxlib现在在CUDA轮子中捆绑了
libdevice.10.bc
。无需将JAX指向CUDA安装以查找此文件。添加了对
jit()
实现的静态关键字参数的自动支持。添加了对预转换异常跟踪的支持。
对从
jit()
转换的计算中修剪未使用参数的初步支持。修剪仍在进行中。改进了
PyTreeDef
对象的字符串表示形式。添加了对XLA的可变ReduceWindow的支持。
错误修复
修复了在将大量参数传递给计算时远程云 TPU 支持中的错误。
修复了一个错误,该错误导致
jit()
转换后的函数未触发 JAX 垃圾回收。
jax 0.2.13 (2021年5月3日)#
新功能
与 jaxlib 0.1.66 结合使用时,
jax.jit()
现在支持静态关键字参数。添加了一个新的static_argnames
选项来指定关键字参数为静态。jax.nonzero()
有一个新的可选size
参数,允许它在jit
中使用 (#6501)jax.numpy.unique()
现在支持axis
参数 (#6532).jax.experimental.host_callback.call()
现在支持pjit.pjit
(#6569).添加了
jax.scipy.linalg.eigh_tridiagonal()
,用于计算三对角矩阵的特征值。目前仅支持特征值。异常中已过滤和未过滤堆栈跟踪的顺序已更改。现在已过滤附加到从 JAX 转换的代码中抛出的异常的回溯,其中
UnfilteredStackTrace
异常包含原始跟踪作为已过滤异常的__cause__
。已过滤的堆栈跟踪现在也适用于 Python 3.6。如果由反向模式自动微分转换的代码抛出异常,JAX 现在尝试将
JaxStackTraceBeforeTransformation
对象作为异常的__cause__
附加,该对象包含在正向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
重大更改
以下函数名称已更改。仍然存在别名,因此这不会破坏现有代码,但最终会删除别名,因此请更改您的代码。
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
类似地,
local_devices()
的参数已从host_id
重命名为process_index
。jax.jit()
的参数(函数除外)现在标记为仅限关键字。此更改是为了防止在向jit
添加参数时意外中断。
错误修复
jaxlib 0.1.65 (2021年4月7日)#
jax 0.2.12 (2021年4月1日)#
新功能
新的分析 API:
jax.profiler.start_trace()
、jax.profiler.stop_trace()
和jax.profiler.trace()
jax.lax.reduce()
现在是可微的。
重大更改
jaxlib 的最低版本现在为 0.1.64。
一些分析器 API 名称已更改。仍然存在别名,因此这不会破坏现有代码,但最终会删除别名,因此请更改您的代码。
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
无法再禁用全阶段。有关更多信息,请参阅 全阶段。
大于最大
int64
值的 Python 整数现在将在所有情况下导致溢出,而不是在某些情况下静默转换为uint64
(#6047).在 X64 模式之外,超出
int32
可表示范围的 Python 整数现在将导致OverflowError
,而不是静默截断其值。
错误修复
host_callback
现在支持参数和结果中的空数组 (#6262).jax.random.randint()
会剪辑而不是环绕超出范围的限制,并且现在可以生成指定 dtype 的完整范围内的整数 (#5868)
jax 0.2.11 (2021年3月23日)#
新功能
错误修复
重大更改
jaxlib 的最低版本现在为 0.1.62。
jaxlib 0.1.64 (2021年3月18日)#
jaxlib 0.1.63 (2021年3月17日)#
jax 0.2.10 (2021年3月5日)#
新功能
jax.scipy.stats.chi2()
现在可作为具有 logpdf 和 pdf 方法的分布使用。jax.scipy.stats.betabinom()
现在可作为具有 logpmf 和 pmf 方法的分布使用。添加了
jax.experimental.jax2tf.call_tf()
以从 JAX 调用 TensorFlow 函数 (#5627) 和 README).扩展了
lax.pad
的批处理规则以支持填充值的批处理。
错误修复
jax.numpy.take()
正确处理负索引 (#5768)
重大更改
JAX 的提升规则已调整,使提升更一致且对 JIT 不变。特别是,二元运算现在可以在适当的时候导致弱类型的值。更改的主要用户可见效果是某些操作产生的输出精度不同于以前;例如,表达式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
以前返回一个float64
数组,现在返回一个bfloat16
数组。JAX 的类型提升行为在 类型提升语义 中进行了描述。jax.numpy.linspace()
现在计算整数的值的下限,即向 -inf 而不是 0 四舍五入。此更改是为了匹配 NumPy 1.20.0。jax.numpy.i0()
不再接受复数。以前该函数计算复数参数的绝对值。此更改是为了匹配 NumPy 1.20.0 的语义。一些
jax.numpy
函数不再接受元组或列表作为数组参数的替代:jax.numpy.pad()
,:funcjax.numpy.ravel
,jax.numpy.repeat()
,jax.numpy.reshape()
。一般来说,jax.numpy
函数应该与标量或数组参数一起使用。
jaxlib 0.1.62 (2021年3月9日)#
新功能
jaxlib 轮子现在默认构建为在 x86-64 机器上需要 AVX 指令。如果您想在不支持 AVX 的机器上使用 JAX,您可以使用
--target_cpu_features
标志从源代码构建 jaxlib 到build.py
。--target_cpu_features
也替换了--enable_march_native
。
jaxlib 0.1.61 (2021年2月12日)#
jaxlib 0.1.60 (2021年2月3日)#
错误修复
修复了将 CPU DeviceArray 转换为 NumPy 数组时的内存泄漏。内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。
bool
、int8
和uint8
现在被认为可以安全地转换为bfloat16
NumPy 扩展类型。
jax 0.2.9 (2021年1月26日)#
新功能
扩展了
jax.experimental.loops
模块,使其支持 pytree。改进了错误检查和错误消息。添加了
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
重大更改
jax.ops.segment_sum()
现在会丢弃超出范围的段 ID,而不是将其包裹到段 ID 空间中。这是出于性能原因。
jaxlib 0.1.59 (2021年1月15日)#
jax 0.2.8 (2021年1月12日)#
新功能
添加了
jax.closure_convert()
用于高阶自定义导数函数。(#5244)添加了
jax.experimental.host_callback.call()
以在主机上调用自定义 Python 函数并将结果返回到设备计算。(#5243)
错误修复
重大更改
jax.numpy.pad
现在接受关键字参数。位置参数constant_values
已被移除。此外,传递不支持的关键字参数会引发错误。针对
jax.experimental.host_callback.id_tap()
的更改(#5243)删除了对
jax.experimental.host_callback.id_tap()
的kwargs
的支持。(此支持已弃用几个月了。)更改了
jax.experimental.host_callback.id_print()
元组的打印方式,使用 ‘(’ 而不是 ‘[’。在存在 JVP 的情况下更改了
jax.experimental.host_callback.id_print()
,以打印一对原始值和切线值。以前,对原始值和切线值分别进行了两次打印操作。host_callback.outfeed_receiver
已被移除(它不是必需的,并且在几个月前已弃用)。
新功能
用于调试
inf
的新标志,类似于NaN
的标志(#5224)。
jax 0.2.7 (2020年12月4日)#
新功能
添加
jax.device_put_replicated
向
jax.experimental.sharded_jit
添加多主机支持添加对
jax.numpy.linalg.eig
计算的特征值的微分支持添加对在 Windows 平台上构建的支持
添加对
jax.pmap
中的一般 in_axes 和 out_axes 的支持添加对
jax.numpy.linalg.slogdet
的复数支持
错误修复
修复了零点处
jax.numpy.sinc
的高于二阶导数修复了一些关于转置规则中符号零的难以触发的错误
重大更改
jax.experimental.optix
已被删除,取而代之的是独立的optax
Python 包。使用非元组序列对 JAX 数组进行索引现在会引发
TypeError
。此类索引自 Numpy v1.16 以来已弃用,自 JAX v0.2.4 以来也已弃用。参见 #4564。
jax 0.2.6 (2020年11月18日)#
新功能
添加对 jax.experimental.jax2tf 转换器的形状多态跟踪的支持。参见 README.md。
重大更改清理
对于 jax.jit 和 xla_computation,对不可散列的静态参数引发错误。参见 cb48f42。
改进类型提升行为的一致性(#4744)
将复数 Python 标量添加到 JAX 浮点数会尊重 JAX 浮点数的精度。例如,
jnp.float32(1) + 1j
现在返回complex64
,而以前返回complex128
。涉及 uint64、带符号整数和第三种类型的 3 个或更多项的类型提升结果现在与参数顺序无关。例如:
jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都返回float16
,而以前第一个返回float64
,第二个返回float16
。
(未记录的)
jax.lax_linalg
线性代数模块的内容现在公开为jax.lax.linalg
。jax.random.PRNGKey
现在在 JIT 编译内外产生相同的结果(#4877)。这需要在一些特定情况下更改给定种子的结果使用
jax_enable_x64=False
时,作为 Python 整数传递的负种子现在在 JIT 模式之外返回不同的结果。例如,jax.random.PRNGKey(-1)
以前返回[4294967295, 4294967295]
,现在返回[0, 4294967295]
。这与 JIT 中的行为一致。在 JIT 外部,超出
int64
可表示范围的种子现在会产生OverflowError
而不是TypeError
。这与 JIT 中的行为一致。
要恢复以前为
jax_enable_x64=False
在 JIT 外部使用负整数返回的密钥,您可以使用key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
当尝试访问已被删除的 DeviceArray 的值时,DeviceArray 现在会引发
RuntimeError
而不是ValueError
。
jaxlib 0.1.58 (2021年1月12日左右)#
修复了一个错误,该错误导致 JAX 有时返回平台特定的类型(例如,
np.cint
),而不是标准类型(例如,np.int32
)。(#4903)修复了在常量折叠某些 int16 操作时发生的崩溃。(#4971)
向
pytree.flatten()
添加了一个is_leaf
谓词。
jaxlib 0.1.57 (2020年11月12日)#
修复了 GPU 轮子中的许多 manylinux2010 兼容性问题。
将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
修复了一个错误,该错误会导致 bfloat16 值的哈希值未正确初始化并可能发生变化(#4651)。
添加对将数组传递到 DLPack 时保留所有权的支持(#4636)。
修复了大小大于 128 但不是 128 的倍数的分批三角求解的错误。
修复了在多个 GPU 上执行并发 FFT 时出现的错误(#3518)。
修复了探查器中工具丢失的错误(#4427)。
放弃对 CUDA 10.0 的支持。
jax 0.2.5 (2020年10月27日)#
改进
确保
check_jaxpr
不执行浮点运算。请参阅#4650。扩展了 jax2tf 转换的 JAX 原语集。请参阅primitives_with_limited_support.md。
jax 0.2.4 (2020年10月19日)#
jaxlib 0.1.56 (2020年10月14日)#
jax 0.2.3 (2020年10月14日)#
如此快地发布另一个版本的原因是我们需要暂时回滚一个新的 jit 快速路径,同时我们正在调查性能下降的问题。
jax 0.2.2 (2020年10月13日)#
jax 0.2.1 (2020年10月6日)#
改进
作为全阶段编译的优势,即使
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的结果未在计算中使用,主机回调函数也会(按程序顺序)执行。
jax (0.2.0) (2020年9月23日)#
jax (0.1.77) (2020年9月15日)#
重大更改
为
jax.experimental.host_callback.id_tap()
提供新的简化接口(#4101)
jaxlib 0.1.55 (2020年9月8日)#
更新 XLA
修复 DLPackManagedTensorToBuffer 中的错误(#4196)
jax 0.1.76 (2020年9月8日)#
jax 0.1.75 (2020年7月30日)#
错误修复
使 jnp.abs() 能够处理无符号输入(#3914)
改进
添加了“全阶段编译”行为,但隐藏在标志后面,默认情况下禁用(#3370)
jax 0.1.74 (2020年7月29日)#
新功能
BFGS(#3101)
对半精度算术的 TPU 支持(#3878)
错误修复
防止某些意外的数据类型警告(#3874)
修复了自定义导数中的多线程错误(#3845、#3869)
改进
更快的 searchsorted 实现(#3873)
jax.numpy 排序算法的测试覆盖率更好(#3836)
jaxlib 0.1.52 (2020年7月22日)#
更新 XLA。
jax 0.1.73 (2020年7月22日)#
jaxlib 的最低版本现在为 0.1.51。
新功能
jax.image.resize.(#3703)
hfft 和 ihfft(#3664)
jax.numpy.intersect1d(#3726)
jax.numpy.lexsort(#3812)
lax.scan
和scan
原语支持一个unroll
参数,用于在降低到 XLA 时进行循环展开(#3738)。
错误修复
修复了归约重复轴错误(#3618)
修复了 lax.pad 的形状规则,用于大小为 0 的输入维度。(#3608)
使 psum 转置能够处理零余切(#3653)
修复了在对大小为 0 的轴上的 reduce-prod 进行 JVP 时出现的形状错误。(#3729)
支持通过 jax.lax.all_to_all 进行微分(#3733)
解决 jax.scipy.special.zeta 中的 NaN 问题(#3777)
改进
对 jax2tf 进行了许多改进
使用单遍变长归约重新实现了 argmin/argmax。(#3611)
默认情况下启用 XLA SPMD 分区。(#3151)
添加对 0d 转置卷积的支持(#3643)
使 LU 梯度能够处理低秩矩阵(#3610)
在 jet 中支持 multiple_results 和自定义 JVP(#3657)
将归约窗口填充泛化为支持 (lo, hi) 对。(#3728)
在 CPU 和 GPU 上实现复数卷积。(#3735)
使 jnp.take 能够处理空数组的空切片。(#3751)
放宽了 dot_general 的维度排序规则。(#3778)
为 GPU 启用缓冲区捐赠。(#3800)
添加对归约窗口操作的基本扩张和窗口扩张的支持…(#3803)
jaxlib 0.1.51 (2020年7月2日)#
更新 XLA。
添加对主机回调的新运行时支持。
jax 0.1.72 (2020年6月28日)#
jax 0.1.71 (2020年6月25日)#
jaxlib 0.1.50 (2020年6月25日)#
添加对 CUDA 11.0 的支持。
放弃对 CUDA 9.2 的支持(我们只维护对最后四个 CUDA 版本的支持)。
更新 XLA。
jaxlib 0.1.49 (2020年6月19日)#
错误修复
修复了可能导致编译速度慢的构建问题(tensorflow/tensorflow)
jaxlib 0.1.48 (2020年6月12日)#
新功能
添加对快速回溯收集的支持。
添加对设备上堆内存分析的初步支持。
为
bfloat16
类型实现了np.nextafter
。在 CPU 和 GPU 上为 FFT 提供 Complex128 支持。
错误修复
提高了 GPU 上 float64
tanh
的精度。GPU 上的 float64 散射速度快得多。
CPU 上的复数矩阵乘法应该快得多。
CPU 上的稳定排序现在应该真正稳定了。
修复了 CPU 后端中的并发错误。
jax 0.1.70 (2020年6月8日)#
jax 0.1.69 (2020年6月3日)#
jax 0.1.68 (2020年5月21日)#
新功能
lax.cond()
支持单操作数形式,作为两个分支的参数#2993。
显著变化
为
jax.experimental.host_callback.id_tap()
原语的transforms
关键字更改了格式#3132。
jax 0.1.67 (2020年5月12日)#
新功能
使用
axis_index_groups
支持对 pmapped 轴的子集进行归约#2382。对从编译代码打印和调用主机端 Python 函数提供了实验性支持。请参阅id_print 和 id_tap(#3006)。
显著变化
从
jax.numpy
导出的名称的可见性已收紧。这可能会破坏以前利用意外导出的名称的代码。
jaxlib 0.1.47 (2020年5月8日)#
修复了输出馈送的崩溃问题。
jax 0.1.66 (2020年5月5日)#
jaxlib 0.1.46 (2020年5月5日)#
修复了 macOS 上线性代数函数的崩溃问题(#432)。
修复了当操作系统或虚拟机禁用 AVX512 指令时导致的非法指令崩溃问题(#2906)。
jax 0.1.65 (2020年4月30日)#
jaxlib 0.1.45 (2020年4月21日)#
修复了段错误:#2755
将 is_stable 选项从 Sort HLO 传递到 Python。
jax 0.1.64 (2020年4月21日)#
新功能
添加了函数式索引更新的语法糖#2684。
为
jax.experimental.jet()
添加更多基本规则。
错误修复
改进错误信息
改进
lax.while_loop()
反向模式微分的错误消息 #2129。
jaxlib 0.1.44 (2020年4月16日)#
修复了一个错误,该错误会导致如果存在多个不同型号的GPU,JAX只会编译适合第一个GPU的程序。
batch_group_count
卷积的错误修复。添加了更多GPU版本的预编译SASS,以避免启动PTX编译挂起。
jax 0.1.63 (2020年4月12日)#
添加了来自 #2026 的
jax.custom_jvp
和jax.custom_vjp
,请参阅 教程笔记本。弃用jax.custom_transforms
并将其从文档中删除(尽管它仍然有效)。添加
scipy.sparse.linalg.cg
#2566。更改了Tracer的打印方式,以显示更多有用的调试信息 #2591。
使
jax.numpy.isclose
正确处理nan
和inf
#2501。为
jax.experimental.jet
添加了几个新规则 #2537。修复了当未提供
scale
/center
时jax.experimental.stax.BatchNorm
的问题。修复了
jax.numpy.einsum
中的一些广播缺失情况 #2512。使用并行前缀扫描实现
jax.numpy.cumsum
和jax.numpy.cumprod
#2596,并使reduce_prod
可微分到任意阶 #2597。将
batch_group_count
添加到conv_general_dilated
#2635。为
test_util.check_grads
添加文档字符串 #2656。添加
callback_transform
#2665。实现
rollaxis
、convolve
/correlate
1d & 2d、copysign
、trunc
、roots
以及quantile
/percentile
插值选项。
jaxlib 0.1.43 (2020年3月31日)#
修复了Resnet-50在GPU上的性能下降。
jax 0.1.62 (2020年3月21日)#
JAX已放弃对Python 3.5的支持。请升级到Python 3.6或更高版本。
删除了内部函数
lax._safe_mul
,该函数实现了约定0. * nan == 0.
。此更改意味着某些程序在进行微分时会生成NaN,而之前会生成正确的值,尽管它确保了对于其他程序生成NaN而不是静默的错误结果。有关详细信息,请参见#2447和#1052。添加了一个
all_gather
并行便捷函数。核心代码中添加了更多类型注释。
jaxlib 0.1.42 (2020年3月19日)#
jaxlib 0.1.41由于API不兼容而破坏了云TPU支持。此版本再次修复了它。
JAX已放弃对Python 3.5的支持。请升级到Python 3.6或更高版本。
jax 0.1.61 (2020年3月17日)#
修复了Python 3.5支持。这将是最后一个支持Python 3.5的JAX或jaxlib版本。
jax 0.1.60 (2020年3月17日)#
新功能
jax.pmap()
具有static_broadcast_argnums
参数,允许用户指定应视为编译时常量并应广播到所有设备的参数。它的工作原理类似于jax.jit()
中的static_argnums
。改进了当Tracer错误地保存在全局状态中的错误消息。
添加了
jax.nn.one_hot()
实用函数。添加了
jax.experimental.jet
用于实现指数级更快的更高阶自动微分。为
jax.lax.broadcast_in_dim()
的参数添加了更多正确性检查。
jaxlib的最低版本现在是0.1.41。
jaxlib 0.1.40 (2020年3月4日)#
在Jaxlib中添加了对TensorFlow Profiler的实验性支持,允许从TensorBoard跟踪CPU和GPU计算。
包含了通过NCCL通信的多主机GPU计算的原型支持。
提高了GPU上NCCL集体通信的性能。
添加了TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA和RandomGamma实现。
支持在XLA编译时已知的设备分配。
jax 0.1.59 (2020年2月11日)#
重大更改
jaxlib的最低版本现在是0.1.38。
通过删除
Jaxpr.freevars
和Jaxpr.bound_subjaxprs
简化了Jaxpr
。调用原语(xla_call
、xla_pmap
、sharded_call
和remat_call
)获得一个新参数call_jaxpr
,该参数具有完全封闭(无constvars
)的jaxpr。此外,还在原语中添加了一个新字段call_primitive
。
新功能
对
lax.cond
进行反向模式自动微分(例如grad
),使其现在在两种模式下都可微分(#2091)JAX现在支持DLPack,它允许以零拷贝方式与其他库(如PyTorch)共享CPU和GPU数组。
JAX GPU DeviceArrays现在支持
__cuda_array_interface__
,这是另一个与其他库(如CuPy和Numba)共享GPU数组的零拷贝协议。JAX CPU设备缓冲区现在实现了Python缓冲区协议,允许在JAX和NumPy之间进行零拷贝缓冲区共享。
添加了JAX_SKIP_SLOW_TESTS环境变量以跳过已知为缓慢的测试。
jaxlib 0.1.39 (2020年2月11日)#
更新XLA。
jaxlib 0.1.38 (2020年1月29日)#
不再支持CUDA 9.0。
现在默认构建CUDA 10.2轮子。
jax 0.1.58 (2020年1月28日)#
重大更改
JAX已放弃Python 2支持,因为Python 2已于2020年1月1日停止使用。请更新到Python 3.5或更高版本。
新功能
while循环的前向模式自动微分(
jvp
)(#1980)
新的NumPy和SciPy函数
GPU上的批量Cholesky分解现在使用更有效的批量内核。
值得注意的错误修复#
随着Python 3的升级,JAX不再依赖于
fastcache
,这应该有助于安装。