变更日志#

最好在此处查看: 这里。有关特定于实验性 Pallas API 的更改,请参阅Pallas 变更日志

JAX 遵循基于努力的版本控制;有关此以及 JAX 的 API 兼容性策略的讨论,请参阅API 兼容性。有关 Python 和 NumPy 版本支持策略,请参阅Python 和 NumPy 版本支持策略

未发布#

  • 更改

    • 最低 NumPy 版本现在是 1.25。NumPy 1.25 将保持最低支持版本,直到 2025 年 6 月。

  • 新功能

  • 弃用

    • jax.interpreters.xlaabstractifypytype_aval_mappings 现在已弃用,已被 jax.core 中具有相同名称的符号替换。

  • 删除

    • jax_enable_memories 标志已被删除,并且该标志的行为默认开启。

    • jax.lib.xla_client,删除了之前已弃用的 DeviceXlaRuntimeError 符号;请分别使用 jax.Devicejax.errors.JaxRuntimeError

jax 0.4.38 (2024年12月17日)#

  • 更改

    • jax.tree.flatten_with_pathjax.tree.map_with_path 被添加为相应的 tree_util 函数的快捷方式。

  • 弃用

    • 内部 jax.core 命名空间中的许多 API 已被弃用。大多数都是无操作,很少使用,或者可以用jax.extend.core中相同名称的 API 替换;有关这些半公开扩展的兼容性保证的信息,请参阅jax.extend的文档。

    • 已删除几个之前已弃用的 API,包括

      • jax.corecheck_eqncheck_typecheck_valid_jaxtypenon_negative_dim

      • jax.lib.xla_bridgexla_clientdefault_backend

      • 来自 jax.lib.xla_client_xlabfloat16

      • 来自 jax.numpyround_

  • 新功能

jax 0.4.37 (2024 年 12 月 9 日)#

这是 jax 0.4.36 的补丁版本。此版本仅发布了“jax”。

  • 错误修复

    • 修复了如果参数名为 fjit 会出错的错误 (#25329)。

    • 修复了如果用户为 flatten 和 flatten_with_path 注册具有不同辅助数据的 pytree 节点类,则在 jax.lax.while_loop() 中抛出 index out of range 错误的 bug。

    • 固定了新的 libtpu 版本 (0.0.6),该版本修复了 TPU v6e 上的编译器错误。

jax 0.4.36 (2024 年 12 月 5 日)#

  • 重大更改

    • 此版本引入了“无栈”,这是 JAX 跟踪机制的内部更改。我们将跟踪调度纯粹设为上下文的函数,而不是上下文和数据的函数。这使我们能够删除大量用于管理数据相关跟踪的机制:级别、子级别、post_process_callnew_base_maincustom_bind 等。此更改应该只会影响使用 JAX 内部的用户。

      如果您确实使用了 JAX 内部,则可能需要更新您的代码(请参阅 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f,了解如何执行此操作的线索)。使用此方法的 JAX 库可能还存在版本偏差问题。如果您发现此更改破坏了您未使用 JAX 内部的代码,请尝试使用 config.jax_data_dependent_tracing_fallback 标志作为解决方法,如果您需要帮助更新代码,请提交错误。

    • 自 2024 年 7 月(JAX 版本 0.4.31)以来,jax.experimental.jax2tf.convert()native_serialization=False 或与 enable_xla=False 一起使用已被弃用。现在我们取消了对这些用例的支持。仍然支持使用原生序列化的 jax2tf

    • jax.interpreters.xla 中,xbxcxe 符号在 JAX v0.4.31 中被弃用后已删除。请改用 xb = jax.lib.xla_bridgexc = jax.lib.xla_clientxe = jax.lib.xla_extension

    • 已删除已弃用的模块 jax.experimental.export。它在 JAX v0.4.30 中被 jax.export 取代。有关迁移到新 API 的信息,请参阅 迁移指南

    • 在 v0.4.27 中被弃用后,jax.nn.softmax()jax.nn.log_softmax()initial 参数已被删除。

    • 现在,在类型化的 PRNG 键(即由 :func:jax.random.key 生成的键)上调用 np.asarray 会引发错误。以前,这会返回一个标量对象数组。

    • 已删除 jax.export 中以下已弃用的方法和函数

      • jax.export.DisabledSafetyCheck.shape_assertions:它已经没有效果。

      • jax.export.Exported.lowering_platforms:请使用 platforms

      • jax.export.Exported.mlir_module_serialization_version:请使用 calling_convention_version

      • jax.export.Exported.uses_shape_polymorphism:请使用 uses_global_constants

      • 用于 jax.export.export()lowering_platforms kwarg:请改用 platforms

    • 已删除 jax.export.symbolic_args_specs() 中的 kwargs symbolic_scopesymbolic_constraints。它们在 2024 年 6 月被弃用。请改用 scopeconstraints

    • 自 0.4.30 版本起被弃用的追踪器哈希现在会导致 TypeError

    • 重构:JAX 构建 CLI (build/build.py) 现在使用子命令结构并替换以前的 build.py 用法。运行 python build/build.py --help 获取更多详细信息。新子命令选项的简要概述

      • build:构建 JAX wheel 包。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt

      • requirements_update:更新 requirements_lock.txt 文件。

    • jax.scipy.linalg.toeplitz() 现在对多维输入进行隐式批处理。要恢复以前的行为,您可以在函数输入上调用 jax.numpy.ravel()

    • 现在,jax.scipy.special.gamma()jax.scipy.special.gammasgn() 对负整数输入返回 NaN,以匹配来自 https://github.com/scipy/scipy/pull/21827 的 SciPy 行为。

    • 在 v0.4.26 中被弃用后,jax.clear_backends 已被删除。

    • 我们从保证导出稳定性的自定义调用列表中删除了自定义调用“__gpu$xla.gpu.triton”。这是因为此自定义调用依赖于 Triton IR,而 Triton IR 不能保证稳定。如果您需要导出使用此自定义调用的代码,可以使用 disabled_checks 参数。请参阅 文档中了解更多详细信息。

  • 新功能

  • 错误修复

    • 修复了 LU 和 QR 分解的 GPU 实现会导致接近 int32 最大值的批处理大小的索引溢出的错误。有关更多详细信息,请参阅 #24843

  • 弃用

    • jax.lib.xla_extension.ArrayImpljax.lib.xla_client.ArrayImpl 已被弃用;请改用 jax.Array

    • jax.lib.xla_extension.XlaRuntimeError 已被弃用;请改用 jax.errors.JaxRuntimeError

jax 0.4.35 (2024 年 10 月 22 日)#

  • 重大更改

    • 现在,jax.numpy.isscalar() 对任何零维类数组对象返回 True。以前,它仅对具有弱 dtype 的零维类数组对象返回 True。

    • 自 2024 年 3 月 JAX 0.4.26 版本起,jax.experimental.host_callback 已被弃用。现在我们已将其移除。有关替代方案的讨论,请参阅 #20385

  • 更改

    • jax.lax.FftType 作为 FFT 操作枚举的公共名称引入。半公开 API jax.lib.xla_client.FftType 已被弃用。

    • TPU:JAX 现在从 libtpu 包而不是 libtpu-nightly 安装 TPU 支持。在接下来的几个版本中,JAX 将同时固定一个空的 libtpu-nightly 版本和 libtpu 版本,以方便过渡;该依赖项将在 2025 年第一季度移除。

  • 弃用

    • 半公开 API jax.lib.xla_client.PaddingType 已被弃用。没有 JAX API 使用此类型,因此没有替代品。

    • jax.pure_callback()jax.extend.ffi.ffi_call()vmap 下的默认行为已被弃用,同时这些函数的 vectorized 参数也被弃用。应使用 vmap_method 参数以获得更明确的行为。有关更多详细信息,请参阅 #23881 中的讨论。

    • 半公开 API jax.lib.xla_client.register_custom_call_target 已被弃用。请改用 JAX FFI。

    • 半公开 API jax.lib.xla_client.dtype_to_etypejax.lib.xla_client.opsjax.lib.xla_client.shape_from_pyvaljax.lib.xla_client.PrimitiveTypejax.lib.xla_client.Shapejax.lib.xla_client.XlaBuilderjax.lib.xla_client.XlaComputation 已被弃用。请改用 StableHLO。

jax 0.4.34(2024 年 10 月 4 日)#

  • 新功能

    • 此版本包含 Python 3.13 的 wheel 文件。尚不支持自由线程模式。

    • jax.errors.JaxRuntimeError 已作为之前私有的 XlaRuntimeError 类型的公共别名添加。

  • 重大更改

    • jax_pmap_no_rank_reduction 标志默认设置为 True

      • 对 pmap 结果执行 array[0] 现在会引入重塑(请改用 array[0:1])。

      • 每个分片的形状(可通过 jax_array.addressable_shards 或 jax_array.addressable_data(0) 访问)现在具有前导的 (1, …)。请更新直接访问分片的代码。现在,每个分片形状的秩与全局形状的秩相匹配,这与 jit 的行为相同。这样可以避免在将 pmap 的结果传递到 jit 时进行昂贵的重塑。

    • jax.experimental.host_callback 自 2024 年 3 月 JAX 0.4.26 版本起已被弃用。现在,我们将 --jax_host_callback_legacy 配置值的默认值设置为 True,这意味着如果您的代码使用了 jax.experimental.host_callback API,那么这些 API 调用将通过新的 jax.experimental.io_callback API 实现。如果这导致您的代码崩溃,您可以在有限的时间内将 --jax_host_callback_legacy 设置为 True。我们很快将删除该配置选项,因此您应该过渡到使用新的 JAX 回调 API。有关讨论,请参阅 #20385

  • 弃用

    • jax.numpy.trim_zeros() 中,非类数组参数或 ndim != 1 的类数组参数现在已被弃用,并且将来会导致错误。

    • 在 JAX v0.4.30 中弃用后,内部漂亮的打印工具 jax.core.pp_* 已被删除。

    • jax.lib.xla_client.Device 已被弃用;请改用 jax.Device

    • jax.lib.xla_client.XlaRuntimeError 已被弃用。请改用 jax.errors.JaxRuntimeError

  • 删除

    • jax.xla_computation 已被删除。自从在 0.4.30 JAX 版本中弃用以来已经过去了 3 个月。请使用 AOT API 来获得与 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以替换为 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您还可以使用 jax.stages.Lowered.out_info 属性来获取输出信息(例如树结构、形状和 dtype)。

      • 对于跨后端降低,您可以使用 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo') 来替换 jax.xla_computation(fn, backend='tpu')(*args, **kwargs)

    • jax.ShapeDtypeStruct 不再接受 named_shape 参数。该参数仅由 0.4.31 版本中删除的 xmap 使用。

    • jax.tree.map(f, None, non-None),之前发出 DeprecationWarning,现在在未来的 jax 版本中会引发错误。None 只是自身的树前缀。为了保留当前行为,您可以要求 jax.tree.mapNone 视为叶子值,方法是编写: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

    • jax.sharding.XLACompatibleSharding 已被删除。请使用 jax.sharding.Sharding

  • 错误修复

    • 修复了如果提供了非布尔输入并指定了 dtype=booljax.numpy.cumsum() 会产生不正确输出的错误。

    • 编辑 jax.numpy.ldexp() 的实现以获得正确的梯度。

jax 0.4.33(2024 年 9 月 16 日)#

这是在 jax 0.4.32 之上发布的补丁版本,修复了该版本中发现的两个错误。

在 JAX 0.4.32 固定的 libtpu 版本中发现了一个仅限 TPU 的数据损坏错误,该错误仅在同一作业中存在多个 TPU 片时才会出现,例如,在多个 v5e 片上进行训练时。此版本通过固定 libtpu 的固定版本来修复该问题。

此版本修复了 CPU 上 F64 tanh 的不准确结果 (#23590)。

jax 0.4.32(2024 年 9 月 11 日)#

注意:由于 TPU 上存在数据损坏错误,此版本已从 PyPi 中删除。有关更多详细信息,请参阅 0.4.33 版本说明。

  • 新功能

  • 更改

    • jax_enable_memories 标志默认设置为 True

    • jax.numpy 现在支持 Python Array API 标准的 v2023.12 版本。有关更多信息,请参阅 Python Array API 标准

    • 现在,在更多情况下,CPU 后端上的计算可能会异步调度。以前,非并行计算始终同步调度。您可以通过设置 jax.config.update('jax_cpu_enable_async_dispatch', False) 来恢复旧行为。

    • 添加了新的 jax.process_indices() 函数来替换在 JAX v0.2.13 中已弃用的 jax.host_ids() 函数。

    • 为了与 numpy.fabs 的行为保持一致,jax.numpy.fabs 已被修改为不再支持 complex dtypes

    • 如果 nodetype 是数据类,则 jax.tree_util.register_dataclass 现在会检查 data_fieldsmeta_fields 是否包含所有 init=True 的数据类字段,并且仅包含这些字段。

    • 现在有几个 jax.numpy 函数具有完整的 ufunc 接口,包括 addmultiplybitwise_andbitwise_orbitwise_xorlogical_andlogical_andlogical_and。在未来的版本中,我们计划将这些扩展到其他 ufuncs。

    • 添加了 jax.lax.optimization_barrier(),它允许用户阻止编译器优化,例如公共子表达式消除,并控制调度。

  • 重大更改

    • MHLO MLIR 方言(jax.extend.mlir.mhlo)已被删除。请改用 stablehlo 方言。

  • 弃用

    • 自从 JAX v0.4.27 弃用后,不再允许将复数输入到 jax.numpy.clip()jax.numpy.hypot() 中。

    • 以下 API 已被弃用:

      • jax.lib.xla_bridge.xla_client:请直接使用 jax.lib.xla_client

      • jax.lib.xla_bridge.get_backend:请使用 jax.extend.backend.get_backend()

      • jax.lib.xla_bridge.default_backend:请使用 jax.extend.backend.default_backend()

    • jax.experimental.array_api 模块已弃用,不再需要导入它来使用 Array API。jax.numpy 直接支持数组 API;有关更多信息,请参阅 Python 数组 API 标准

    • 内部实用程序 jax.core.check_eqnjax.core.check_typejax.core.check_valid_jaxtype 现在已弃用,并将在未来删除。

    • jax.numpy.round_ 已被弃用,原因是 NumPy 2.0 中删除了相应的 API。请改用 jax.numpy.round()

    • 将 DLPack capsule 传递给 jax.dlpack.from_dlpack() 已被弃用。jax.dlpack.from_dlpack() 的参数应该是来自另一个实现了 __dlpack__ 协议的框架的数组。

jaxlib 0.4.32(2024 年 9 月 11 日)#

注意:由于 TPU 上存在数据损坏错误,此版本已从 PyPi 中删除。有关更多详细信息,请参阅 0.4.33 版本说明。

  • 重大更改

    • 此版本的 jaxlib 切换到新版本的 CPU 后端,该后端应编译得更快并更好地利用并行性。如果由于此更改而遇到任何问题,您可以通过设置环境变量 XLA_FLAGS=--xla_cpu_use_thunk_runtime=false 暂时启用旧的 CPU 后端。如果您需要这样做,请提交 JAX 错误并提供重现说明。

    • 添加了 Hermetic CUDA 支持。Hermetic CUDA 使用特定的可下载 CUDA 版本,而不是用户本地安装的 CUDA。Bazel 将下载 CUDA、CUDNN 和 NCCL 发行版,然后在各种 Bazel 目标中使用 CUDA 库和工具作为依赖项。这使得 JAX 及其支持的 CUDA 版本的构建更具可重复性。

  • 更改

    • 添加了 SparseCore 分析。

      • JAX 现在支持在 TPUv5p 芯片上分析 SparseCore。这些跟踪将在 Tensorboard Profiler 的 TraceViewer 中查看。

jax 0.4.31(2024 年 7 月 29 日)#

  • 删除

    • xmap 已被删除。请使用 shard_map() 作为替代。

  • 更改

    • 最低 CuDNN 版本为 v9.1。在之前的版本中也是如此,但我们现在正式声明此版本约束。

    • 最低 Python 版本现在为 3.10。3.10 将在 2025 年 7 月之前保持为最低支持版本。

    • 最低 NumPy 版本现在为 1.24。NumPy 1.24 将在 2024 年 12 月之前保持为最低支持版本。

    • 最低 SciPy 版本现在为 1.10。SciPy 1.10 将在 2025 年 1 月之前保持为最低支持版本。

    • jax.numpy.ceil()jax.numpy.floor()jax.numpy.trunc() 现在返回与输入相同数据类型的输出,即不再将整数或布尔输入向上转换为浮点数。

    • libdevice.10.bc 不再与 CUDA wheels 捆绑在一起。它必须作为本地 CUDA 安装的一部分安装,或者通过 NVIDIA 的 CUDA pip wheels 安装。

    • jax.experimental.pallas.BlockSpec 现在期望在 index_map 之前传递 block_shape。旧的参数顺序已弃用,将在未来的版本中删除。

    • 更新了 GPU 设备的 repr,使其与 TPU/CPU 更一致。例如,cuda(id=0) 现在将是 CudaDevice(id=0)

    • jax.Array 添加了 device 属性和 to_device 方法,作为 JAX 的 Array API 支持的一部分。

  • 弃用

    • 删除了许多先前已弃用的与多态形状相关的内部 API。从 jax.core 中:删除了 canonicalize_shapedimension_as_valuedefinitely_equalsymbolic_equal_dim

    • HLO 降级规则不应再将单例 ir.Values 包装在元组中。相反,应返回未包装的单例 ir.Values。对包装值的支持将在未来版本的 JAX 中删除。

    • 现在已弃用带有 native_serialization=Falseenable_xla=Falsejax.experimental.jax2tf.convert(),此支持将在未来的版本中删除。自 JAX 0.4.16(2023 年 9 月)以来,原生序列化一直是默认设置。

    • 先前已弃用的函数 jax.random.shuffle 已被删除;请改用带有 independent=Truejax.random.permutation

jaxlib 0.4.31(2024 年 7 月 29 日)#

  • 错误修复

    • 修复了一个错误,该错误意味着 jit 调度快速路径错误地处理了 jit 的负 static_argnums。

    • 修复了一个错误,该错误意味着奇异矩阵批次的三角求解会产生无意义的有限值,而不是 inf 或 nan (#3589, #15429)。

jax 0.4.30(2024 年 6 月 18 日)#

  • 更改

    • JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已升级到 0.4.0,但在此版本中已回滚,以便为 TensorFlow 和 JAX 的用户提供更多时间迁移到较新的 TensorFlow 版本。

    • jax.experimental.mesh_utils 现在可以为 TPU v5e 创建高效的网格。

    • jax 现在直接依赖于 jaxlib。此更改是由 CUDA 插件切换启用的:不再有多个 jaxlib 变体。您可以使用 pip install jax 安装仅 CPU 的 jax,无需任何额外组件。

    • 添加了用于导出和序列化 JAX 函数的 API。这曾经存在于 jax.experimental.export(正在弃用)中,现在将存在于 jax.export 中。请参阅文档

  • 弃用

    • 内部美观打印工具 jax.core.pp_* 已弃用,将在未来的版本中删除。

    • 跟踪器的哈希已弃用,并且将在未来的 JAX 版本中导致 TypeError。以前是这种情况,但在最近的几个 JAX 版本中存在一个无意的回归。

    • jax.experimental.export 已弃用。请改用 jax.export。请参阅迁移指南

    • 在大多数情况下,使用数组代替 dtype 传递参数已弃用;例如,对于数组 xyx.astype(y) 将会引发警告。要消除警告,请使用 x.astype(y.dtype)

    • jax.xla_computation 已弃用,将在未来的版本中移除。请使用 AOT API 来获得与 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以替换为 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您还可以使用 jax.stages.Lowered.out_info 属性来获取输出信息(例如树结构、形状和 dtype)。

      • 对于跨后端降低,您可以使用 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo') 来替换 jax.xla_computation(fn, backend='tpu')(*args, **kwargs)

jaxlib 0.4.30 (2024 年 6 月 18 日)#

  • 已删除对单体 CUDA jaxlib 的支持。您必须使用基于插件的安装方式 (pip install jax[cuda12]pip install jax[cuda12_local])。

jax 0.4.29 (2024 年 6 月 10 日)#

  • 更改

    • 我们预计这将是 JAX 和 jaxlib 支持单体 CUDA jaxlib 的最后一个版本。未来的版本将使用 CUDA 插件 jaxlib(例如,pip install jax[cuda12])。

    • JAX 现在需要 ml_dtypes 版本 0.4.0 或更高版本。

    • 移除了对旧版 jax.experimental.export API 用法的向后兼容性支持。不再可能使用 from jax.experimental.export import export,而是应该使用 from jax.experimental import export。此移除的功能自 0.4.24 版本起已被弃用。

    • jax.tree.all()jax.tree_util.tree_all() 添加了 is_leaf 参数。

  • 弃用

    • jax.sharding.XLACompatibleSharding 已弃用。请使用 jax.sharding.Sharding

    • jax.experimental.Exported.in_shardings 已重命名为 jax.experimental.Exported.in_shardings_hloout_shardings 也一样。旧名称将在 3 个月后删除。

    • 移除了一些之前已弃用的 API

      • 来自 jax.corenon_negative_dimDimSizeShape

      • 来自 jax.laxtie_in

      • 来自 jax.nnnormalize

      • 来自 jax.interpreters.xlabackend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXlaOp

    • jax.numpy.linalg.matrix_rank()tol 参数已被弃用,并将很快被删除。请改用 rtol

    • jax.numpy.linalg.pinv()rcond 参数已被弃用,并将很快被删除。请改用 rtol

    • 已移除已弃用的 jax.config 子模块。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • jax.random API 不再接受批处理的键,之前有些 API 意外地接受了。未来,我们建议在此类情况下显式使用 jax.vmap()

    • jax.scipy.special.beta() 中,xy 参数已重命名为 ab,以便与其他 beta API 保持一致。

  • 新功能

    • 添加了 jax.experimental.Exported.in_shardings_jax(),用于从存储在 Exported 对象中的 HloShardings 构建可与 JAX API 一起使用的分片。

jaxlib 0.4.29 (2024 年 6 月 10 日)#

  • 错误修复

    • 修复了 XLA 错误地对某些连接操作进行分片的 bug,该 bug 表现为累积归约的输出不正确 (#21403)。

    • 修复了 XLA:CPU 错误编译某些 matmul 融合的问题 (https://github.com/openxla/xla/pull/13301)。

    • 修复了 GPU 上的编译器崩溃问题 (https://github.com/jax-ml/jax/issues/21396)。

  • 弃用

    • jax.tree.map(f, None, non-None) 现在会发出 DeprecationWarning,并且将在未来版本的 jax 中引发错误。None 只是其自身的树前缀。要保留当前行为,可以要求 jax.tree.mapNone 视为叶值,方法是编写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

jax 0.4.28 (2024 年 5 月 9 日)#

  • 错误修复

    • 回滚了对 make_jaxpr 的更改,该更改破坏了 Equinox (#21116)。

  • 弃用和移除

  • 更改

    • 此版本的最低 jaxlib 版本为 0.4.27。

jaxlib 0.4.28 (2024 年 5 月 9 日)#

  • 错误修复

    • 修复了 Python 3.10 或更早版本中 Array 和 JIT Python 对象的类型名称中的内存损坏错误。

    • 修复了 CUDA 12.4 下的警告 '+ptx84' is not a recognized feature for this target

    • 修复了 CPU 上编译速度缓慢的问题。

  • 更改

    • Windows 版本现在使用 Clang 而不是 MSVC 构建。

jax 0.4.27 (2024 年 5 月 7 日)#

  • 新功能

    • 根据 array API 2023 标准(即将被 NumPy 采用)添加了 jax.numpy.unstack()jax.numpy.cumulative_sum()

    • 添加了一个新的配置选项 jax_cpu_collectives_implementation,用于选择 CPU 后端使用的跨进程集体操作的实现。可用的选项有 'none' (默认)、'gloo''mpi' (需要 jaxlib 0.4.26)。如果设置为 'none',则禁用跨进程集体操作。

  • 更改

    • jax.pure_callback(), jax.experimental.io_callback()jax.debug.callback() 现在使用 jax.Array 而不是 np.ndarray。你可以通过在将参数传递给回调之前使用 jax.tree.map(np.asarray, args) 转换参数来恢复旧的行为。

    • complex_arr.astype(bool) 现在遵循与 NumPy 相同的语义,当 complex_arr 等于 0 + 0j 时返回 False,否则返回 True。

    • core.Token 现在是一个非平凡的类,它包装了一个 jax.Array。它可以被创建并穿梭于计算中以建立依赖关系。单例对象 core.token 已被移除,用户现在应该创建和使用新的 core.Token 对象来代替。

    • 在 GPU 上,Threefry PRNG 的实现默认不再降低为内核调用。这种选择可以在编译时增加开销的情况下提高运行时内存使用率。可以通过 jax.config.update('jax_threefry_gpu_kernel_lowering', True) 来恢复产生内核调用的先前行为。如果新的默认行为导致问题,请提交 bug。否则,我们计划在未来的版本中删除此标志。

  • 弃用 & 移除

    • Pallas 现在专门使用 XLA 在 GPU 上编译内核。通过 Triton Python API 的旧的降低过程已被移除,并且 JAX_TRITON_COMPILE_VIA_XLA 环境变量不再起任何作用。

    • jax.numpy.clip() 有一个新的参数签名:aa_mina_max 已被弃用,取而代之的是 x(仅限位置参数)、minmax (#20550)。

    • JAX 数组的 device() 方法已被移除,该方法自 JAX v0.4.21 起已被弃用。请改用 arr.devices()

    • jax.nn.softmax()jax.nn.log_softmax()initial 参数已被弃用;现在支持 softmax 的空输入,无需设置此参数。

    • jax.jit() 中,传递无效的 static_argnumsstatic_argnames 现在会导致错误,而不是警告。

    • 现在最低 jaxlib 版本为 0.4.23。

    • 当传递复数值输入给 jax.numpy.hypot() 函数时,它现在会发出弃用警告。在弃用完成后,这将引发错误。

    • 现在,jax.numpy.nonzero()jax.numpy.where() 和相关函数的标量参数会引发错误,这与 NumPy 中的类似更改一致。

    • 配置选项 jax_cpu_enable_gloo_collectives 已被弃用。请改用 jax.config.update('jax_cpu_collectives_implementation', 'gloo')

    • 在 JAX v0.4.22 中被弃用后,jax.Array.device_bufferjax.Array.device_buffers 方法已被移除。请改用 jax.Array.addressable_shardsjax.Array.addressable_data()

    • 在 JAX v0.4.21 中关键字被弃用后,jax.numpy.whereconditionxy 参数现在仅限位置参数。

    • jax.lax.linalg 中函数的非数组参数现在必须通过关键字指定。之前,这会引发 DeprecationWarning。

    • 现在,jax.numpy 中的几个 API 中需要类数组参数,包括 apply_along_axis()apply_over_axes()inner()outer()cross()kron()lexsort()

  • 错误修复

    • copy=True 时,jax.numpy.astype() 现在将始终返回一个副本。以前,当输出数组与输入数组具有相同的数据类型时,不会创建副本。这可能会导致一些内存使用量的增加。默认值设置为 copy=False,以保留向后兼容性。

jaxlib 0.4.27 (2024 年 5 月 7 日)#

jax 0.4.26 (2024 年 4 月 3 日)#

  • 新功能

  • 更改

    • 复数值的 jax.numpy.geomspace() 现在选择与 NumPy 2.0 一致的对数螺旋分支。

    • lax.rng_bit_generator 的行为,以及 'rbg''unsafe_rbg' PRNG 实现,在 jax.vmap已更改,以便映射键仅导致从批处理中的第一个键生成随机数。

    • 文档现在使用 jax.random.key 来构造 PRNG 键数组,而不是 jax.random.PRNGKey

  • 弃用 & 移除

    • jax.tree_map() 已被弃用;请改用 jax.tree.map,或者为了向后兼容旧的 JAX 版本,请使用 jax.tree_util.tree_map()

    • jax.clear_backends() 已被弃用,因为它不一定像其名称所暗示的那样执行,并且可能会导致意外的后果,例如,它不会销毁现有的后端并释放相应的拥有资源。如果你只想清理编译缓存,请使用 jax.clear_caches()。为了向后兼容,或者你确实需要切换/重新初始化默认后端,请使用 jax.extend.backend.clear_backends()

    • jax.experimental.maps 模块和 jax.experimental.maps.xmap 已被弃用。请使用 jax.experimental.shard_map 或带有 spmd_axis_name 参数的 jax.vmap 来表达 SPMD 设备并行计算。

    • jax.experimental.host_callback 模块已被弃用。请改用新的 JAX 外部回调。添加了 JAX_HOST_CALLBACK_LEGACY 标志,以帮助过渡到新的回调。有关讨论,请参阅 #20385

    • 现在,将无法转换为 JAX 数组的参数传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 会导致异常。

    • 已移除已弃用的标志 jax_parallel_functions_output_gda。该标志早已被弃用,并且没有任何作用;它的使用是空操作。

    • 之前已弃用的导入 jax.interpreters.ad.configjax.interpreters.ad.source_info_util 现在已被移除。请改用 jax.configjax.extend.source_info_util

    • JAX 导出不再支持旧的序列化版本。版本 9 自 2023 年 10 月 27 日起已受支持,并自 2024 年 2 月 1 日起成为默认版本。请参阅版本描述。此更改可能会破坏设置了低于 9 的特定 JAX 序列化版本的客户端。

jaxlib 0.4.26 (2024 年 4 月 3 日)#

  • 更改

    • JAX 现在仅支持 CUDA 12.1 或更高版本。已放弃对 CUDA 11.8 的支持。

    • JAX 现在支持 NumPy 2.0。

jax 0.4.25 (2024 年 2 月 26 日)#

  • 新功能

  • 更改

    • Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将 JAX_TRITON_COMPILE_VIA_XLA 环境变量设置为 "0" 来恢复旧的行为。

    • 在 v0.4.24 中删除的 jax.interpreters.xla 中的几个已弃用的 API 已在 v0.4.25 中重新添加,包括 backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXLAOp。 这些仍然被认为是已弃用的,并且将来在有更好的替代品时将再次被删除。 有关讨论,请参阅 #19816

  • 弃用 & 移除

    • jax.numpy.linalg.solve() 现在针对 b.ndim > 1 的批量 1D 求解显示弃用警告。将来,这些将被视为批量 2D 求解。

    • 现在,将非标量数组转换为 Python 标量会引发错误,而与数组的大小无关。以前,在大小为 1 的非标量数组的情况下会发出弃用警告。这遵循 NumPy 中的类似弃用。

    • 在标准的 3 个月弃用周期之后,以前已弃用的配置 API 已被移除(请参阅 API 兼容性)。这些包括

      • jax.config.config 对象和

      • jax.configdefine_*_stateDEFINE_* 方法。

    • 通过 import jax.config 导入 jax.config 子模块已被弃用。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。

    • 最低 jaxlib 版本现在为 0.4.20。

jaxlib 0.4.25 (2024 年 2 月 26 日)#

jax 0.4.24 (2024 年 2 月 6 日)#

  • 更改

    • JAX 降低到 StableHLO 不再依赖于物理设备。如果您的原语在降低规则中包装了 custom_partitioning 或 JAX 回调,即传递给 mlir.register_loweringrule 参数的函数,那么请将您的原语添加到 jax._src.dispatch.prim_requires_devices_during_lowering 集合中。这是必需的,因为 custom_partitioning 和 JAX 回调需要物理设备才能在降低期间创建 Sharding。在我们可以在没有物理设备的情况下创建 Sharding 之前,这是一个临时状态。

    • jax.numpy.argsort()jax.numpy.sort() 现在支持 stabledescending 参数。

    • 对形状多态性(在 jax.experimental.jax2tfjax.experimental.export 中使用)的处理进行了一些更改

      • 符号表达式的更清晰的漂亮打印 (#19227)

      • 添加了指定维度变量的符号约束的功能。这使得形状多态性更具表现力,并提供了一种解决不等式推理局限性的方法。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 随着符号约束的添加 (#19235),我们现在认为来自不同范围的维度变量是不同的,即使它们具有相同的名称。来自不同范围的符号表达式无法交互,例如,在算术运算中。范围由 jax.experimental.jax2tf.convert()jax.experimental.export.symbolic_shape()jax.experimental.export.symbolic_args_specs() 引入。可以使用 e.scope 读取符号表达式 e 的范围,并将其传递到上述函数中,以指导它们在给定范围内构造符号表达式。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 简化且更快的相等性比较,如果它们的差的标准化形式简化为 0,则我们认为两个符号维度相等 (#19231;请注意,这可能会导致用户可见的行为更改)

      • 改进了不确定不等式比较的错误消息 (#19235)。

      • 已弃用 core.non_negative_dim API(最近引入),并引入了 core.max_dimcore.min_dim (#18953) 以表达符号维度的 maxmin。您可以使用 core.max_dim(d, 0) 而不是 core.non_negative_dim(d)

      • 弃用 shape_poly.is_poly_dim,改用 export.is_symbolic_dim (#19282)。

      • 弃用 export.args_specs,改用 export.symbolic_args_specs ({jax-issue}#19283`)。

      • 弃用 shape_poly.PolyShapejax2tf.PolyShape,请使用字符串来表示多态形状规范 (#19284)。

      • JAX 默认的本机序列化版本现在为 9。这与 jax.experimental.jax2tfjax.experimental.export 相关。请参阅版本号的描述

    • 重构了 jax.experimental.export 的 API。您现在应该使用 from jax.experimental import export,而不是使用 from jax.experimental.export import export。旧的导入方式将在 3 个月的弃用期内继续有效。

    • 添加了 jax.scipy.stats.sem()

    • 带有 return_inverse = Truejax.numpy.unique() 返回的逆索引会按照与 NumPy 2.0 中 numpy.unique() 类似的更改,被重塑为输入数据的维度。

    • jax.numpy.sign() 现在对于非零复数输入返回 x / abs(x)。这与 NumPy 2.0 版本中 numpy.sign() 的行为一致。

    • 带有 return_sign=Truejax.scipy.special.logsumexp() 现在对于复数的符号使用 NumPy 2.0 的约定 x / abs(x)。这与 SciPy v1.13 中 scipy.special.logsumexp() 的行为一致。

    • JAX 现在支持导入和导出布尔型的 DLPack 类型。以前,布尔值无法导入,并且导出为整数。

  • 弃用 & 移除

    • 许多先前已弃用的函数已被删除,遵循标准的 3 个月以上的弃用周期(请参阅 API 兼容性)。这包括:

      • 来自 jax.coreTracerArrayConversionErrorTracerIntegerConversionErrorUnexpectedTracerErroras_hashable_functioncollectionsdtypeslumapnamedtuplepartialpprefsafe_zipsafe_mapsource_info_utiltotal_orderingtraceback_utiltuple_deletetuple_insertzip

      • 来自 jax.laxdtypesitertoolsnaryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule

      • jax.linear_util 子模块及其所有内容。

      • jax.prng 子模块及其所有内容。

      • 来自 jax.randomPRNGKeyArrayKeyArraydefault_prng_implthreefry_2x32threefry2x32_keythreefry2x32_prbg_keyunsafe_rbg_key

      • 来自 jax.tree_utilregister_keypathsAttributeKeyPathEntryGetItemKeyPathEntry

      • 来自 jax.interpreters.xlabackend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextaxis_groupsShapedArrayConcreteArrayAxisEnvbackend_compileXLAOp

      • 来自 jax.numpyNINFNZEROPZEROrow_stackissubsctypetrapzin1d

      • 来自 jax.scipy.linalgtriltriu

    • 先前已弃用的方法 PRNGKeyArray.unsafe_raw_array 已被删除。请改用 jax.random.key_data()

    • bool(empty_array) 现在会引发错误,而不是返回 False。之前这会引发弃用警告,并且与 NumPy 中的类似更改保持一致。

    • 对 mhlo MLIR 方言的支持已被弃用。JAX 不再使用 mhlo 方言,而是使用 stablehlo。将来会删除引用 “mhlo” 的 API。请改用 “stablehlo” 方言。

    • jax.random:直接将批处理的键传递给随机数生成函数(例如 bits()gamma() 等)已被弃用,并会发出 FutureWarning。请使用 jax.vmap 进行显式批处理。

    • jax.lax.tie_in() 已被弃用:自 JAX v0.2.0 以来,它一直是一个空操作。

jaxlib 0.4.24 (2024 年 2 月 6 日)#

  • 更改

    • JAX 现在支持 CUDA 12.3 和 CUDA 11.8。已删除对 CUDA 12.2 的支持。

    • cost_analysis 现在可以与交叉编译的 Compiled 对象一起使用(即,当使用带有拓扑对象的 .lower().compile() 时,例如,从非 TPU 计算机编译用于云 TPU)。

    • 添加了 CUDA 数组接口 导入支持(需要 jax 0.4.25)。

jax 0.4.23 (2023 年 12 月 13 日)#

jaxlib 0.4.23 (2023 年 12 月 13 日)#

  • 修复了在编译期间导致 GPU 编译器输出大量日志的错误。

jax 0.4.22 (2023 年 12 月 13 日)#

  • 弃用

    • JAX 数组的 device_bufferdevice_buffers 属性已被弃用。显式缓冲区已替换为更灵活的数组分片接口,但可以通过以下方式恢复先前的输出:

      • arr.device_buffer 变为 arr.addressable_data(0)

      • arr.device_buffers 变为 [x.data for x in arr.addressable_shards]

jaxlib 0.4.22 (2023 年 12 月 13 日)#

jax 0.4.21 (2023 年 12 月 4 日)#

  • 新功能

  • 更改

    • 最低 jaxlib 版本现在为 0.4.19。

    • 现在使用 clang 而不是 gcc 构建发布的 wheels。

    • 强制在调用 jax.distributed.initialize() 之前未初始化设备后端。

    • 自动化云 TPU 环境中 jax.distributed.initialize() 的参数。

  • 弃用

    • 已从 jax.scipy.linalg.solve() 中删除先前已弃用的 sym_pos 参数。请改用 assume_a='pos'

    • None 直接或在列表或元组中传递给 jax.array()jax.asarray() 已被弃用,现在会引发 FutureWarning。它目前被转换为 NaN,将来会引发 TypeError

    • 为了与 numpy.where 保持一致,使用关键字参数将 conditionxy 参数传递给 jax.numpy.where 的做法已被弃用。

    • 将无法转换为 JAX 数组的参数传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 的做法已被弃用,现在会引发 DeprecationWaning 警告。目前,这些函数返回 False,未来将会引发异常。

    • JAX 数组的 device() 方法已被弃用。根据上下文,它可能被以下方法之一取代:

      • jax.Array.devices() 返回数组使用的所有设备的集合。

      • jax.Array.sharding 提供数组使用的分片配置。

jaxlib 0.4.21 (2023 年 12 月 4 日)#

  • 更改

    • 为了准备添加分布式 CPU 支持,JAX 现在将 CPU 设备与 GPU 和 TPU 设备同等对待,也就是说:

      • jax.devices() 包括分布式作业中存在的所有设备,即使是那些不在当前进程本地的设备。jax.local_devices() 仍然只包括当前进程本地的设备。因此,如果对 jax.devices() 的更改破坏了您的代码,您很可能需要改用 jax.local_devices()

      • CPU 设备现在在分布式作业中接收一个全局唯一的 ID 号;以前,CPU 设备会接收一个进程本地的 ID 号。

      • 每个 CPU 设备的 process_index 现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前,CPU 设备的 process_index 始终为 0。

    • 在 NVIDIA GPU 上,JAX 现在优先使用 Jacobi SVD 求解器处理最大 1024x1024 的矩阵。Jacobi 求解器似乎比非 Jacobi 版本更快。

  • 错误修复

    • 修复了当将具有非有限值的数组传递给非对称特征分解时出现的错误/挂起问题 (#18226)。现在,具有非有限值的数组会产生充满 NaN 的输出数组。

jax 0.4.20 (2023 年 11 月 2 日)#

jaxlib 0.4.20 (2023 年 11 月 2 日)#

  • 错误修复

    • 修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆问题。

jax 0.4.19 (2023 年 10 月 19 日)#

  • 新功能

    • 添加了 jax.typing.DTypeLike,可用于注释可转换为 JAX dtypes 的对象。

    • 添加了 jax.numpy.fill_diagonal

  • 更改

    • JAX 现在需要 SciPy 1.9 或更高版本。

  • 错误修复

    • 在多控制器分布式 JAX 程序中,只有进程 0 会写入持久编译缓存条目。这修复了如果缓存放置在 GCS 等网络文件系统上时出现的写入争用问题。

    • 在确定已安装的库版本是否至少与构建 JAX 时使用的版本一样新时,cusolver 和 cufft 的版本检查不再考虑补丁版本。

jaxlib 0.4.19 (2023 年 10 月 19 日)#

  • 更改

    • 如果安装了 pip 安装的 NVIDIA CUDA 库(nvidia-… 包),jaxlib 现在将始终优先选择它们,而不是任何其他 CUDA 安装,包括在 LD_LIBRARY_PATH 中指定的安装。如果这导致问题并且目的是使用系统安装的 CUDA,则解决方法是删除 pip 安装的 CUDA 库包。

jax 0.4.18 (2023 年 10 月 6 日)#

jaxlib 0.4.18 (2023 年 10 月 6 日)#

  • 更改

    • CUDA jaxlibs 现在依赖于用户安装兼容的 NCCL 版本。如果使用推荐的 cuda12_pip 安装,则应自动安装 NCCL。目前,需要 NCCL 2.16 或更高版本。

    • 我们现在提供 Linux aarch64 wheels,包括有和没有 NVIDIA GPU 支持的两种版本。

    • jax.Array.item() 现在支持可选的索引参数。

  • 弃用

    • jax.lax 中的一些内部实用程序和无意导出的内容已被弃用,将在未来的版本中删除。

      • jax.lax.dtypes:请改用 jax.dtypes

      • jax.lax.itertools:请改用 itertools

      • naryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule 是内部实用程序,现在已被弃用,没有替代方案。

  • 错误修复

    • 修复了由于 smem 导致的 Cloud TPU 编译 OOM 回归问题。

jax 0.4.17 (2023 年 10 月 3 日)#

  • 新功能

  • 弃用

    • 删除了已弃用的模块 jax.abstract_arrays 及其所有内容。

    • jax.random 中的命名键构造函数已被弃用。请将 impl 参数传递给 jax.random.PRNGKey()jax.random.key()

      • random.threefry2x32_key(seed) 变为 random.PRNGKey(seed, impl='threefry2x32')

      • random.rbg_key(seed) 变为 random.PRNGKey(seed, impl='rbg')

      • random.unsafe_rbg_key(seed) 变为 random.PRNGKey(seed, impl='unsafe_rbg')

  • 更改

    • CUDA:JAX 现在验证它找到的 CUDA 库至少与构建 JAX 时使用的 CUDA 库一样新。如果找到较旧的库,JAX 会引发异常,因为这比神秘的故障和崩溃更可取。

    • 删除了“未找到 GPU/TPU”警告。相反,如果在 Linux 上找到 NVIDIA GPU 或 Google TPU 但未使用,并且未指定 --jax_platforms,则会发出警告。

    • jax.scipy.stats.mode() 现在如果对大小为 0 的轴取众数,则会返回计数 0,这与 SciPy 1.11 中 scipy.stats.mode 的行为相匹配。

    • 大多数 jax.numpy 函数和属性现在都有完全定义的类型存根。以前,许多这些都被像 mypypytype 这样的静态类型检查器视为 Any

jaxlib 0.4.17 (2023 年 10 月 3 日)#

  • 更改

    • 此版本中添加了 Python 3.12 wheels。

    • CUDA 12 wheels 现在需要 CUDA 12.2 或更高版本以及 cuDNN 8.9.4 或更高版本。

  • 错误修复

    • 修复了初始化 JAX CPU 后端时 ABSL 产生的日志垃圾邮件。

jax 0.4.16 (2023 年 9 月 18 日)#

  • 更改

    • 添加了 jax.numpy.ufunc,以及 jax.numpy.frompyfunc(),它可以将任何标量值函数转换为类似 numpy.ufunc() 的对象,具有诸如 outer()reduce()accumulate()at()reduceat() 等方法 (#17054)。

    • 添加了 jax.scipy.integrate.trapezoid()

    • 当不在 IPython 下运行时:当引发异常时,JAX 现在会从回溯中过滤掉其所有内部帧。(没有以前出现的“未过滤的堆栈跟踪”。)这应该会产生更加友好的回溯。有关示例,请参见 此处。可以通过设置 JAX_TRACEBACK_FILTERING=remove_frames (对于两个单独的未过滤/过滤的回溯,这是以前的行为)或 JAX_TRACEBACK_FILTERING=off (对于一个未过滤的回溯)来更改此行为。

    • jax2tf 默认序列化版本现在是 7,它引入了新的形状 安全断言

    • 传递给 jax.sharding.Mesh 的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices() 已经是可哈希的。

  • 重大更改

    • jax2tf 现在默认使用原生序列化。有关详细信息以及覆盖默认机制的方法,请参阅jax2tf 文档

    • 选项 --jax_coordination_service 已被移除。它现在始终为 True

    • jax.jaxpr_util 已从公共 JAX 命名空间中移除。

    • JAX_USE_PJRT_C_API_ON_TPU 不再起作用(即,它始终默认为 true)。

    • 在 2021 年 12 月引入的向后兼容标志 --jax_host_callback_ad_transforms 已被移除。

  • 弃用

    • 根据 NumPy NEP-52,一些 jax.numpy API 已被弃用。

      • jax.numpy.NINF 已被弃用。请改用 -jax.numpy.inf

      • jax.numpy.PZERO 已被弃用。请改用 0.0

      • jax.numpy.NZERO 已被弃用。请改用 -0.0

      • jax.numpy.issubsctype(x, t) 已被弃用。请使用 jax.numpy.issubdtype(x.dtype, t)

      • jax.numpy.row_stack 已被弃用。请改用 jax.numpy.vstack

      • jax.numpy.in1d 已被弃用。请改用 jax.numpy.isin

      • jax.numpy.trapz 已被弃用。请改用 jax.scipy.integrate.trapezoid

    • 根据 SciPy,jax.scipy.linalg.triljax.scipy.linalg.triu 已被弃用。请改用 jax.numpy.triljax.numpy.triu

    • jax.lax.prod 在 JAX v0.4.11 中被弃用后已移除。请改用内置的 math.prod

    • 一些与为自定义 JAX 原语定义 HLO 降级规则相关的 jax.interpreters.xla 导出已被弃用。自定义原语应改用 jax.interpreters.mlir 中的 StableHLO 降级实用程序定义。

    • 以下之前已弃用的函数在三个月的弃用期后已被移除

      • jax.abstract_arrays.ShapedArray:请使用 jax.core.ShapedArray

      • jax.abstract_arrays.raise_to_shaped:请使用 jax.core.raise_to_shaped

      • jax.numpy.alltrue:请使用 jax.numpy.all

      • jax.numpy.sometrue:请使用 jax.numpy.any

      • jax.numpy.product:请使用 jax.numpy.prod

      • jax.numpy.cumproduct:请使用 jax.numpy.cumprod

  • 弃用/移除

    • 内部子模块 jax.prng 现在已弃用。其内容可在 jax.extend.random 中找到。

    • 内部子模块路径 jax.linear_util 已被弃用。请改用 jax.extend.linear_util(属于 jax.extend:扩展模块

    • jax.random.PRNGKeyArrayjax.random.KeyArray 已被弃用。请使用 jax.Array 进行类型注释,并使用 jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) 进行类型化 prng 密钥的运行时检测。

    • 方法 PRNGKeyArray.unsafe_raw_array 已被弃用。请改用 jax.random.key_data()

    • jax.experimental.pjit.with_sharding_constraint 已被弃用。请改用 jax.lax.with_sharding_constraint

    • 内部实用程序 jax.core.is_opaque_dtypejax.core.has_opaque_dtype 已被移除。不透明 dtypes 已重命名为扩展 dtypes;请改用 jnp.issubdtype(dtype, jax.dtypes.extended)(自 jax v0.4.14 起可用)。

    • 实用程序 jax.interpreters.xla.register_collective_primitive 已被移除。此实用程序在最近的 JAX 版本中没有执行任何有用的操作,并且可以安全地删除对其的调用。

    • 内部子模块路径 jax.linear_util 已被弃用。请改用 jax.extend.linear_util(属于 jax.extend:扩展模块

jaxlib 0.4.16 (2023 年 9 月 18 日)#

  • 更改

    • 通过实验性的 jax sparse API 进行的稀疏 CSR 矩阵乘法不再在 NVIDIA GPU 上使用确定性算法。此更改是为了提高与 CUDA 12.2.1 的兼容性而进行的。

  • 错误修复

    • 修复了 Windows 上由于与无序部分和 IMAGE_REL_AMD64_ADDR32NB 重定位相关的致命 LLVM 错误导致的崩溃问题 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。

jax 0.4.14 (2023 年 7 月 27 日)#

  • 更改

    • jax.jitdonate_argnames 作为参数。它的语义与 static_argnames 类似。如果未提供 donate_argnums 和 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找与 donate_argnames 对应的任何位置参数(或反之亦然)。如果同时提供了 donate_argnums 和 donate_argnames,则不会使用 inspect.signature,并且只会捐赠 donate_argnums 或 donate_argnames 中列出的实际参数。

    • jax.random.gamma() 已被重构为更高效的算法,具有更强大的端点行为 (#16779)。这意味着,对于给定的 key,在 JAX v0.4.13 和 v0.4.14 之间,gamma 和相关的采样器(包括 jax.random.ball()jax.random.beta()jax.random.chisquare()jax.random.dirichlet()jax.random.generalized_normal()jax.random.loggamma()jax.random.t()) 返回的值序列将会发生变化。

  • 删除

    • in_axis_resourcesout_axis_resources 已从 pjit 中删除,因为它们被弃用已超过 3 个月。请使用 in_shardingsout_shardings 作为替代。这是一个安全且微不足道的名称替换。它不会更改当前任何 pjit 语义,也不会破坏任何代码。您仍然可以将 PartitionSpecs 传递给 in_shardings 和 out_shardings。

  • 弃用

    • 根据 https://jax.ac.cn/en/latest/deprecation.html,已放弃对 Python 3.8 的支持。

    • 根据 https://jax.ac.cn/en/latest/deprecation.html,JAX 现在需要 NumPy 1.22 或更高版本。

    • 在 JAX 版本 0.4.7 中被弃用后,不再支持按位置将可选参数传递给 jax.numpy.ndarray.at()。例如,请使用 x.at[i].get(indices_are_sorted=True),而不是 x.at[i].get(True)

    • 以下 jax.Array 方法在 JAX v0.4.5 中被弃用后已被移除

    • 以下 API 在之前弃用后已被移除

      • jax.ad:请使用 jax.interpreters.ad

      • jax.curry:请使用 curry = lambda f: partial(partial, f)

      • jax.partial_eval:请使用 jax.interpreters.partial_eval

      • jax.pxla:请使用 jax.interpreters.pxla

      • jax.xla:请使用 jax.interpreters.xla

      • jax.ShapedArray: 请使用 jax.core.ShapedArray

      • jax.interpreters.pxla.device_put: 请使用 jax.device_put()

      • jax.interpreters.pxla.make_sharded_device_array: 请使用 jax.make_array_from_single_device_arrays()

      • jax.interpreters.pxla.ShardedDeviceArray: 请使用 jax.Array

      • jax.numpy.DeviceArray: 请使用 jax.Array

      • jax.stages.Compiled.compiler_ir: 请使用 jax.stages.Compiled.as_text()

  • 重大更改

    • JAX 现在要求 ml_dtypes 版本为 0.2.0 或更高版本。

    • 为了修复一个极端情况,如果第二个和第三个参数是可调用的,则对具有五个参数的 jax.lax.cond() 的调用将始终解析为“通用操作数” cond 行为(如文档所述),即使其他操作数也是可调用的。请参阅 #16413

    • 已删除已弃用的配置选项 jax_arrayjax_jit_pjit_api_merge,它们没有任何作用。这些选项在许多版本中默认设置为 true。

  • 新功能

    • JAX 现在支持一个配置标志 –jax_serialization_version 和一个 JAX_SERIALIZATION_VERSION 环境变量来控制序列化版本 (#16746)。

    • 如果序列化版本至少为 7,则 jax2tf 在存在形状多态性的情况下,现在会生成检查某些形状约束的代码。请参阅 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。

jaxlib 0.4.14 (2023 年 7 月 27 日)#

  • 弃用

    • 根据 https://jax.ac.cn/en/latest/deprecation.html,已放弃对 Python 3.8 的支持。

jax 0.4.13 (2023 年 6 月 22 日)#

  • 更改

    • jax.jit 现在允许将 None 传递给 in_shardingsout_shardings。 语义如下:

      • 对于 in_shardings,JAX 将其标记为复制,但此行为将来可能会更改。

      • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

    • jax.experimental.pjit.pjit 也允许将 None 传递给 in_shardingsout_shardings。 语义如下:

      • 如果提供网格上下文管理器,则 JAX 可以自由选择它想要的任何分片。

        • 对于 in_shardings,JAX 将其标记为复制,但此行为将来可能会更改。

        • 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

      • 如果提供了网格上下文管理器,则 None 将意味着该值将在网格的所有设备上复制。

    • Executable.cost_analysis() 在 Cloud TPU 上工作

    • 如果正在使用非白名单的 jaxlib 插件,则会添加警告。

    • 添加了 jax.tree_util.tree_leaves_with_path

    • None 不是 jax.experimental.multihost_utils.host_local_array_to_global_arrayjax.experimental.multihost_utils.global_array_to_host_local_array 的有效输入。 如果您想要复制输入,请使用 jax.sharding.PartitionSpec()

  • 错误修复

    • 修复了 CUDA 12 版本中不正确的 wheel 名称 (#16362);正确的 wheel 名称为 cudnn89 而不是 cudnn88

  • 弃用

    • 为了支持新的 native_serializaation_disabled_checksjax.experimental.jax2tf.convert()native_serialization_strict_checks 参数已弃用(#16347)。

jaxlib 0.4.13 (2023 年 6 月 22 日)#

  • 更改

    • 将 Windows 仅 CPU wheel 添加到 jaxlib Pypi 版本。

  • 错误修复

    • __cuda_array_interface__ 在之前的 jaxlib 版本中被破坏,现在已修复 (#16440)。

    • 现在默认在 NVIDIA GPU 上启用并发 CUDA 内核跟踪。

jax 0.4.12 (2023 年 6 月 8 日)#

  • 更改

  • 弃用

    • jax.abstract_arrays 及其内容现在已弃用。 请参阅 jax.core 中的相关功能。

    • jax.numpy.alltrue: 请使用 jax.numpy.all。 这遵循了 NumPy 1.25.0 版本中对 numpy.alltrue 的弃用。

    • jax.numpy.sometrue: 请使用 jax.numpy.any。 这遵循了 NumPy 1.25.0 版本中对 numpy.sometrue 的弃用。

    • jax.numpy.product: 请使用 jax.numpy.prod。 这遵循了 NumPy 1.25.0 版本中对 numpy.product 的弃用。

    • jax.numpy.cumproduct: 请使用 jax.numpy.cumprod。 这遵循了 NumPy 1.25.0 版本中对 numpy.cumproduct 的弃用。

    • jax.sharding.OpShardingSharding 已被删除,因为它已被弃用 3 个月了。

jaxlib 0.4.12 (2023 年 6 月 8 日)#

  • 更改

    • 包含用于 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。 以前版本的 jaxlib 应该可以在 Hopper 上工作,但在第一次执行 JAX 操作时会有很长的 JIT 编译延迟。

  • 错误修复

    • 修复了在 Python 3.11 下 JAX 生成的 Python 回溯中不正确的源行信息。

    • 修复了在 JAX 生成的 Python 回溯中打印帧的局部变量时发生的崩溃 (#16027)。

jax 0.4.11 (2023 年 5 月 31 日)#

  • 弃用

    • 根据 API 兼容性 策略,在 3 个月的弃用期后,以下 API 已被删除

      • jax.experimental.PartitionSpec: 请使用 jax.sharding.PartitionSpec

      • jax.experimental.maps.Mesh: 请使用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding: 请使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec: 请使用 jax.sharding.PartitionSpec

      • jax.experimental.pjit.FROM_GDA。 相反,请将分片的 jax.Array 对象作为输入传递,并删除 pjit 的可选 in_shardings 参数。

      • jax.interpreters.pxla.PartitionSpec: 请使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh: 请使用 jax.sharding.Mesh

      • jax.interpreters.xla.Buffer: 请使用 jax.Array

      • jax.interpreters.xla.Device: 请使用 jax.Device

      • jax.interpreters.xla.DeviceArray: 请使用 jax.Array

      • jax.interpreters.xla.device_put: 请使用 jax.device_put

      • jax.interpreters.xla.xla_call_p: 请使用 jax.experimental.pjit.pjit_p

      • 删除了 with_sharding_constraintaxis_resources 参数。 请改用 shardings

jaxlib 0.4.11 (2023 年 5 月 31 日)#

  • 更改

    • Device 添加了 memory_stats() 方法。 如果支持,这将返回一个包含字符串统计名称和整数值的字典,例如 "bytes_in_use",或者如果平台不支持内存统计信息,则返回 None。 返回的确切统计信息可能因平台而异。 当前仅在 Cloud TPU 上实现。

    • 在 CPU 设备上重新添加了对 Python 缓冲区协议 (memoryview) 的支持。

jax 0.4.10 (2023 年 5 月 11 日)#

jaxlib 0.4.10 (2023 年 5 月 11 日)#

  • 更改

    • 修复了 'apple-m1' is not a recognized processor for this target (ignoring processor) 问题,该问题阻止了以前的版本在 Mac M1 上运行。

jax 0.4.9 (2023 年 5 月 9 日)#

  • 更改

    • 已删除标志 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap。 它们现在始终处于启用状态。

    • 提高了 TPU 上奇异值分解 (SVD) 的准确性(需要 jaxlib 0.4.9)。

  • 弃用

    • jax.experimental.gda_serialization 已被弃用,并已重命名为 jax.experimental.array_serialization。请更改您的导入语句以使用 jax.experimental.array_serialization

    • pjit 的 in_axis_resourcesout_axis_resources 参数已被弃用。请分别使用 in_shardingsout_shardings

    • 函数 jax.numpy.msort 已被移除。它自 JAX v0.4.1 起已被弃用。请改用 jnp.sort(a, axis=0)

    • in_partsout_parts 参数已从 jax.xla_computation 中移除,因为它们仅与 sharded_jit 一起使用,而 sharded_jit 早已不再使用。

    • instantiate_const_outputs 参数已从 jax.xla_computation 中移除,因为它已经很长时间未使用了。

jaxlib 0.4.9(2023 年 5 月 9 日)#

jax 0.4.8(2023 年 3 月 29 日)#

  • 重大更改

    • Cloud TPU 运行时的一个主要组件已升级。这使得在 Cloud TPU 上启用以下新功能

      jax.experimental.host_callback() 在新的运行时组件下不再支持 Cloud TPU。如果新的 jax.debug API 不足以满足您的使用情况,请在 JAX 问题跟踪器上提交问题。

      旧的运行时组件将通过设置环境变量 JAX_USE_PJRT_C_API_ON_TPU=false 在至少未来三个月内可用。如果您发现出于任何原因需要禁用新的运行时,请在 JAX 问题跟踪器上告知我们。

  • 更改

    • 最低 jaxlib 版本已从 0.4.6 提升至 0.4.7。

  • 弃用

    • 已放弃 CUDA 11.4 的支持。JAX GPU 轮子仅支持 CUDA 11.8 和 CUDA 12。如果 jaxlib 从源代码构建,则较旧的 CUDA 版本可能有效。

    • pmap 的 global_arg_shapes 参数仅与 sharded_jit 一起使用,并且已从 pmap 中删除。请迁移到 pjit 并从 pmap 中删除 global_arg_shapes。

jax 0.4.7(2023 年 3 月 27 日)#

  • 更改

    • 根据 https://jax.ac.cn/en/latest/jax_array_migration.html#jax-array-migration,jax.config.jax_array 无法再禁用。

    • jax.config.jax_jit_pjit_api_merge 无法再禁用。

    • jax.experimental.jax2tf.convert() 现在支持 native_serialization 参数,以使用 JAX 的原生降级到 StableHLO,从而为整个 JAX 函数获取 StableHLO 模块,而不是将每个 JAX 基元降级为 TensorFlow 操作。这简化了内部结构,并增加了您序列化的内容与 JAX 原生语义匹配的信心。请参阅文档。作为此更改的一部分,配置标志 --jax2tf_default_experimental_native_lowering 已重命名为 --jax2tf_native_serialization

    • JAX 现在依赖于 ml_dtypes,其中包含诸如 bfloat16 之类的 NumPy 类型的定义。这些定义以前是 JAX 的内部定义,但已拆分为单独的包,以便与其他项目共享。

    • JAX 现在需要 NumPy 1.21 或更高版本以及 SciPy 1.7 或更高版本。

  • 弃用

    • 类型 jax.numpy.DeviceArray 已被弃用。请改用 jax.Array,它是一个别名。

    • 类型 jax.interpreters.pxla.ShardedDeviceArray 已被弃用。请改用 jax.Array

    • 通过位置传递其他参数给 jax.numpy.ndarray.at() 已被弃用。例如,请使用 x.at[i].get(indices_are_sorted=True),而不是 x.at[i].get(True)

    • jax.interpreters.xla.device_put 已被弃用。请使用 jax.device_put

    • jax.interpreters.pxla.device_put 已被弃用。请使用 jax.device_put

    • jax.experimental.pjit.FROM_GDA 已被弃用。请传入分片的 jax.Arrays 作为输入,并删除 pjit 的 in_shardings 参数,因为它是可选的。

jaxlib 0.4.7(2023 年 3 月 27 日)#

更改

  • jaxlib 现在依赖于 ml_dtypes,其中包含诸如 bfloat16 之类的 NumPy 类型的定义。这些定义以前是 JAX 的内部定义,但已拆分为单独的包,以便与其他项目共享。

jax 0.4.6(2023 年 3 月 9 日)#

  • 更改

    • jax.tree_util 现在包含一组 API,允许用户为其自定义 pytree 节点定义键。这包括

      • tree_flatten_with_path,它会展平树并返回每个叶子以及它们的键路径。

      • tree_map_with_path,它可以映射一个以键路径作为参数的函数。

      • register_pytree_with_keys,用于注册自定义 pytree 节点中键路径和叶子的外观。

      • keystr,它会漂亮地打印键路径。

    • jax2tf.call_tf() 有一个新的参数 output_shape_dtype(默认为 None),可用于声明结果的输出形状和类型。这使 jax2tf.call_tf() 能够在存在形状多态性的情况下工作。(#14734)。

  • 弃用

    • jax.tree_util 中的旧键路径 API 已被弃用,将在 2023 年 3 月 10 日起的 3 个月内移除。

jaxlib 0.4.6(2023 年 3 月 9 日)#

jax 0.4.5(2023 年 3 月 2 日)#

  • 弃用

    • jax.sharding.OpShardingSharding 已重命名为 jax.sharding.GSPMDShardingjax.sharding.OpShardingSharding 将在 2023 年 2 月 17 日起的 3 个月内移除。

    • 以下 jax.Array 方法已弃用,将在 2023 年 2 月 23 日起的 3 个月内移除

jax 0.4.4(2023 年 2 月 16 日)#

  • 更改

    • jitpjit 的实现已合并。合并 jit 和 pjit 更改了 JAX 的内部结构,但不影响 JAX 的公共 API。之前,jit 是 final 风格的原语。Final 风格意味着 jaxpr 的创建会尽可能延迟,并且转换会相互堆叠。通过 jit-pjit 实现合并,jit 变成了一个 initial 风格的原语,这意味着我们会尽早跟踪到 jaxpr。有关更多信息,请参阅autodidax 中的此部分。转向 initial 风格应该简化 JAX 的内部结构,并使动态形状等功能的开发更加容易。您只能通过环境变量禁用它,例如 os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。必须通过环境变量禁用合并,因为它会在导入时影响 JAX,因此需要在导入 jax 之前禁用它。

    • with_sharding_constraintaxis_resources 参数已弃用。请改用 shardings。如果您将 axis_resources 用作 arg,则无需更改。如果您将其用作 kwarg,请改用 shardingsaxis_resources 将在 2023 年 2 月 13 日之后的 3 个月内删除。

    • 添加了 jax.typing 模块,其中包含用于 JAX 函数类型注释的工具。

    • 以下名称已被弃用

      • jax.xla.Devicejax.interpreters.xla.Device:请使用 jax.Device

      • jax.experimental.maps.Mesh。请改用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding: 请使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec: 请使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh:请使用 jax.sharding.Mesh

      • jax.interpreters.pxla.PartitionSpec: 请使用 jax.sharding.PartitionSpec

  • 重大更改

    • 类似于 jax.numpy.sum 的归约函数的 initial 参数现在必须是标量,与相应的 NumPy API 一致。之前将输出广播到非标量 initial 值的行为是无意的实现细节 (#14446)。

jaxlib 0.4.4 (2023 年 2 月 16 日)#

  • 重大更改

    • 默认的 jaxlib 构建版本已删除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,仍然可以从源代码构建具有 Kepler 支持的 jaxlib(通过 build.py--cuda_compute_capabilities=sm_35 选项),但请注意,CUDA 12 已完全放弃对 Kepler GPU 的支持。

jax 0.4.3 (2023 年 2 月 8 日)#

jaxlib 0.4.3 (2023 年 2 月 8 日)#

  • jax.Array 现在具有非阻塞的 is_ready() 方法,如果数组已准备就绪,则返回 True(另请参阅 jax.block_until_ready())。

jax 0.4.2 (2023 年 1 月 24 日)#

  • 重大更改

    • 删除了 jax.experimental.callback

    • 在存在 jax2tf 形状多态性的情况下,具有维度的操作已进行泛化,以便在更多场景中工作,方法是将符号维度转换为 JAX 数组。当结果用作形状值时,涉及符号维度和 np.ndarray 的操作现在可能会引发错误 (#14106)。

    • jaxpr 对象现在会在设置属性时引发错误,以避免有问题的突变 (#14102)

  • 更改

    • jax2tf.call_tf() 有一个新参数 has_side_effects(默认值 True),可用于声明实例是否可以被 JAX 优化(例如死代码消除)删除或复制 (#13980)。

    • 为 jax2tf 形状多态性添加了对 floordiv 和 mod 的更多支持。以前,在存在符号维度的情况下,某些除法运算会导致错误 (#14108)。

jaxlib 0.4.2 (2023 年 1 月 24 日)#

  • 更改

    • 设置 JAX_USE_PJRT_C_API_ON_TPU=1 以启用新的 Cloud TPU 运行时,该运行时具有自动设备内存碎片整理功能。

jax 0.4.1 (2022 年 12 月 13 日)#

  • 更改

    • 根据 JAX 的Python 和 NumPy 版本支持策略,已删除对 Python 3.7 的支持。

    • 我们引入了 jax.Array,它是一个统一的数组类型,它包含 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 类型。jax.Array 类型有助于使并行性成为 JAX 的核心功能,简化和统一 JAX 内部结构,并使我们能够统一 jitpjitjax.Array 已在 JAX 0.4 中默认启用,并且对 pjit API 进行了一些重大更改。jax.Array 迁移指南可以帮助您将代码库迁移到 jax.Array。您还可以查看分布式数组和自动并行化教程,以了解新概念。

    • PartitionSpecMesh 现在已退出实验阶段。新的 API 端点是 jax.sharding.PartitionSpecjax.sharding.Meshjax.experimental.maps.Meshjax.experimental.PartitionSpec 已弃用,将在 3 个月内删除。

    • with_sharding_constraint 的新公共端点是 jax.lax.with_sharding_constraint

    • 如果将 ABSL 标志与 jax.config 一起使用,则在最初从 ABSL 标志填充 JAX 配置选项后,不再读取或写入 ABSL 标志值。此更改提高了读取 jax.config 选项的性能,这些选项在 JAX 中被广泛使用。

    • jax2tf.call_tf 函数现在使用与嵌入式 JAX 计算所使用的平台相同的第一个 TF 设备进行 TF 降低。以前,它使用 JAX 默认后端的第 0 个设备。

    • 许多 jax.numpy 函数现在将其参数标记为仅位置参数,与 NumPy 匹配。

    • 根据 numpy 1.24 中 np.msort 的弃用,jnp.msort 现在已被弃用。它将在未来的版本中删除,符合API 兼容性策略。可以用 jnp.sort(a, axis=0) 替换它。

jaxlib 0.4.1 (2022 年 12 月 13 日)#

  • 更改

    • 根据 JAX 的Python 和 NumPy 版本支持策略,已删除对 Python 3.7 的支持。

    • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX 的行为已更改为分配总 GPU 内存的 XX%,而不是以前使用当前可用的 GPU 内存来计算预分配的行为。有关更多详细信息,请参阅GPU 内存分配

    • 已删除已弃用的方法 .block_host_until_ready()。请改用 .block_until_ready()

jax 0.4.0 (2022 年 12 月 12 日)#

  • 此版本已撤回。

jaxlib 0.4.0 (2022 年 12 月 12 日)#

  • 此版本已撤回。

jax 0.3.25 (2022 年 11 月 15 日)#

jaxlib 0.3.25 (2022 年 11 月 15 日)#

  • 更改

    • 添加了对 CPU 和 GPU 上三对角线简化的支持。

    • 添加了对 CPU 上上Hessenberg 简化的支持。

  • 错误修复

    • 修复了一个错误,该错误导致 JAX 捕获的回溯中的帧在 Python 3.10+ 下被错误地映射到源代码行。

jax 0.3.24 (2022 年 11 月 4 日)#

  • 更改

    • JAX 的导入速度应该更快。我们现在延迟导入 scipy,这占用了 JAX 导入时间的大部分。

    • 设置环境变量 JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N 可以用来限制写入持久缓存的缓存条目数量。默认情况下,编译时间为 1 秒或更长的计算将被缓存。

    • 如果未指定顺序,pmap 在 TPU 上使用的默认设备顺序现在与单进程作业的 jax.devices() 匹配。之前,这两个顺序不同,这可能会导致不必要的复制或内存不足错误。要求顺序一致可以简化问题。

  • 重大更改

  • 弃用

    • jax.sharding.MeshPspecSharding 已重命名为 jax.sharding.NamedShardingjax.sharding.MeshPspecSharding 名称将在 3 个月后删除。

jaxlib 0.3.24 (2022 年 11 月 4 日)#

  • 更改

    • 缓冲区捐赠现在可以在 CPU 上工作。这可能会破坏在 CPU 上标记了缓冲区进行捐赠,但依赖于捐赠未实现的的代码。

jax 0.3.23 (2022 年 10 月 12 日)#

  • 更改

    • 更新 Colab TPU 驱动程序版本以支持新的 jaxlib 版本。

jax 0.3.22 (2022 年 10 月 11 日)#

  • 更改

    • 在 TPU 初始化中添加 JAX_PLATFORMS=tpu,cpu 作为默认设置,以便在 TPU 无法初始化时,JAX 会引发错误,而不是回退到 CPU。设置 JAX_PLATFORMS='' 以覆盖此行为并自动选择可用的后端(原始默认设置),或者设置 JAX_PLATFORMS=cpu 以始终使用 CPU,无论 TPU 是否可用。

  • 弃用

    • 在 JAX v0.3.8 中已弃用的多个测试实用程序现在已从 jax.test_util 中删除。

jaxlib 0.3.22 (2022 年 10 月 11 日)#

jax 0.3.21 (2022 年 9 月 30 日)#

  • GitHub 提交.

  • 更改

    • 持久编译缓存现在会在出错时发出警告,而不是引发异常 (#12582),因此如果缓存出现问题,程序执行可以继续。设置 JAX_RAISE_PERSISTENT_CACHE_ERRORS=true 以恢复此行为。

jax 0.3.20 (2022 年 9 月 28 日)#

  • 错误修复

    • 添加了上一个版本中缺少的 .pyi 文件 (#12536)。

    • 修复了 jax 0.3.19 与其所固定的 libtpu 版本之间的不兼容问题 (#12550)。需要 jaxlib 0.3.20。

    • 修复了 setup.py 注释中不正确的 pip url (#12528)。

jaxlib 0.3.20 (2022 年 9 月 28 日)#

  • GitHub 提交.

  • 错误修复

    • 修复了在分布式作业中通过 jax_cuda_visible_devices 限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成是必需的 (#12533)。

jax 0.3.19 (2022 年 9 月 27 日)#

jax 0.3.18 (2022 年 9 月 26 日)#

  • GitHub 提交.

  • 更改

    • 提前降低和编译功能(在 #7733 中跟踪)是稳定且公开的。请参阅 概述jax.stages 的 API 文档。

    • 引入了 jax.Array,旨在用于 JAX 中数组类型的 isinstance 检查和类型注释。请注意,这包括对 isinstance 如何处理 jax.numpy.ndarray 的一些细微更改,因为 jax.numpy.ndarray 现在是 jax.Array 的简单别名。

  • 重大更改

    • jax._src 不再导入到公共的 jax 命名空间中。这可能会破坏使用 JAX 内部结构的用户的代码。

    • jax.soft_pmap 已删除。请改用 pjitxmapjax.soft_pmap 没有文档记录。如果它有文档记录,则会提供一个弃用期。

jax 0.3.17 (2022 年 8 月 31 日)#

  • GitHub 提交.

  • 错误修复

    • 修复了 lax.pow 的梯度在指数为零时的边界情况问题 (#12041)

  • 重大更改

    • jax.checkpoint(),也称为 jax.remat(),不再支持 concrete 选项,遵循之前版本的弃用;请参阅 JEP 11830

  • 更改

    • 添加了 jax.pure_callback(),它允许从编译后的函数(例如用 jax.jitjax.pmap 修饰的函数)回调到纯 Python 函数。

  • 弃用

    • 已弃用的 DeviceArray.tile() 方法已删除。请使用 jax.numpy.tile() (#11944)。

    • DeviceArray.to_py() 已被弃用。请改用 np.asarray(x)

jax 0.3.16#

jax 0.3.15 (2022 年 7 月 22 日)#

jaxlib 0.3.15 (2022 年 7 月 22 日)#

jax 0.3.14 (2022 年 6 月 27 日)#

  • GitHub 提交.

  • 重大更改

    • jax.experimental.compilation_cache.initialize_cache() 不再支持 max_cache_size_  bytes,并且不会将其作为输入。

    • 当平台初始化失败时,JAX_PLATFORMS 现在会引发异常。

  • 更改

    • 修复了与 NumPy 1.23 的兼容性问题。

    • jax.numpy.linalg.slogdet() 现在接受可选的 method 参数,该参数允许在基于 LU 分解的实现和基于 QR 分解的实现之间进行选择。

    • jax.numpy.linalg.qr() 现在支持 mode="raw"

    • 现在,当在 jax 数组上使用 picklecopy.copycopy.deepcopy 时,它们具有更完整的支持 (#10659)。特别是:

      • 以前,当在 DeviceArray 上使用 pickledeepcopy 时,它们会返回 np.ndarray 对象;现在返回 DeviceArray 对象。对于 deepcopy,复制的数组与原始数组在同一设备上。对于 pickle,反序列化的数组将位于默认设备上。

      • 在函数转换(即,跟踪的代码)中,deepcopycopy 以前是空操作。现在,它们使用与 DeviceArray.copy() 相同的机制。

      • 在跟踪的数组上调用 pickle 现在会导致显式的 ConcretizationTypeError

    • 奇异值分解 (SVD) 和对称/厄米特特征分解的实现在 TPU 上应该会快得多,特别是对于 1000x1000 或更大的矩阵。两者现在都使用谱分治算法进行特征分解 (QDWH-eig)。

    • jax.numpy.ldexp() 不再静默地将所有输入提升为 float64,而是将大小为 int32 或更小的整数输入提升为 float32 (#10921)。

    • jax.profiler.start_trace()jax.profiler.start_trace() 添加 create_perfetto_link 选项。使用时,分析器将生成一个指向 Perfetto UI 的链接以查看跟踪。

    • 更改了 jax.profiler.start_server(...)() 的语义,以全局存储 keepalive,而不是要求用户保持对其的引用。

    • 添加了 jax.random.generalized_normal()

    • 添加了 jax.random.ball()

    • 添加了 jax.default_device()

    • 添加了 python -m jax.collect_profile 脚本,以手动捕获程序跟踪,作为 TensorBoard UI 的替代方案。

    • 添加了 jax.named_scope 上下文管理器,它将分析器元数据添加到 Python 程序中(类似于 jax.named_call)。

    • 在散布更新操作(即 :attr:jax.numpy.ndarray.at)中,不安全的隐式数据类型转换已被弃用,现在会产生 FutureWarning。在未来的版本中,这将变为错误。不安全的隐式转换的一个例子是 jnp.zeros(4, dtype=int).at[0].set(1.5),其中 1.5 之前会被静默截断为 1

    • jax.experimental.compilation_cache.initialize_cache() 现在支持将 gcs 存储桶路径作为输入。

    • 添加了 jax.scipy.stats.gennorm()

    • 当系数具有前导零时,strip_zeros=False 时的 jax.numpy.roots() 现在表现更好 (#11215)。

jaxlib 0.3.14 (2022 年 6 月 27 日)#

  • GitHub 提交.

    • x86-64 Mac 版本现在要求 Mac OS 10.14 (Mojave) 或更高版本。Mac OS 10.14 于 2018 年发布,因此这应该不是一个非常苛刻的要求。

    • 捆绑的 NCCL 版本已更新至 2.12.12,修复了一些死锁问题。

    • Python flatbuffers 包不再是 jaxlib 的依赖项。

jax 0.3.13 (2022 年 5 月 16 日)#

jax 0.3.12 (2022 年 5 月 15 日)#

jax 0.3.11 (2022 年 5 月 15 日)#

  • GitHub 提交.

  • 更改

    • jax.lax.eigh() 现在接受一个可选的 sort_eigenvalues 参数,允许用户选择不启用 TPU 上的特征值排序。

  • 弃用

    • jax.lax.linalg 中函数的非数组参数现在被标记为仅限关键字。作为向后兼容步骤,以位置方式传递仅限关键字的参数会产生警告,但在未来的 JAX 版本中,以位置方式传递仅限关键字的参数将会失败。但是,大多数用户应首选使用 jax.numpy.linalg

    • 作为 scipy API 的 JAX 扩展的 jax.scipy.linalg.polar_unitary() 已被弃用。请改用 jax.scipy.linalg.polar()

jax 0.3.10 (2022 年 5 月 3 日)#

jaxlib 0.3.10 (2022 年 5 月 3 日)#

  • GitHub 提交.

  • 更改

    • TF 提交修复了 MHLO 规范化器中的一个问题,该问题导致常量折叠对于某些程序花费很长时间或崩溃。

jax 0.3.9 (2022 年 5 月 2 日)#

  • GitHub 提交.

  • 更改

    • 增加了对 GlobalDeviceArray 的完全异步检查点支持。

jax 0.3.8 (2022 年 4 月 29 日)#

  • GitHub 提交.

  • 更改

    • TPU 上的 jax.numpy.linalg.svd() 使用 qdwh-svd 求解器。

    • TPU 上的 jax.numpy.linalg.cond() 现在接受复数输入。

    • TPU 上的 jax.numpy.linalg.pinv() 现在接受复数输入。

    • TPU 上的 jax.numpy.linalg.matrix_rank() 现在接受复数输入。

    • 添加了 jax.scipy.cluster.vq.vq()

    • jax.experimental.maps.mesh 已被删除。请使用 jax.experimental.maps.Mesh。有关更多信息,请参阅 https://jax.ac.cn/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh。

    • 为了与 scipy.linalg.qr 的行为相匹配 (#10452),当 mode='r' 时,jax.scipy.linalg.qr() 现在返回长度为 1 的元组,而不是原始数组。

    • jax.numpy.take_along_axis() 现在采用一个可选的 mode 参数,该参数指定越界索引的行为。默认情况下,越界索引将返回无效值(例如,NaN)。在以前的 JAX 版本中,无效索引会被钳制到范围内。可以通过传递 mode="clip" 来恢复以前的行为。

    • jax.numpy.take() 现在默认为 mode="fill",对于越界索引返回无效值(例如,NaN)。

    • 散布操作(例如 x.at[...].set(...))现在具有 "drop" 语义。这对散布操作本身没有影响,但这意味着当微分时,散布的梯度将对越界索引产生零余切。以前,对于梯度,越界索引会被钳制到范围内,这在数学上是不正确的。

    • 如果 jax.numpy.take_along_axis() 的索引不是整数类型,则现在会引发 TypeError,这与 numpy.take_along_axis() 的行为相匹配。以前,非整数索引会被静默转换为整数。

    • 如果 jax.numpy.ravel_multi_index()dims 参数不是整数类型,则现在会引发 TypeError,这与 numpy.ravel_multi_index() 的行为相匹配。以前,非整数 dims 会被静默转换为整数。

    • 如果 jax.numpy.split()axis 参数不是整数类型,则现在会引发 TypeError,这与 numpy.split() 的行为相匹配。以前,非整数 axis 会被静默转换为整数。

    • 如果 jax.numpy.indices() 的维度不是整数类型,则现在会引发 TypeError,这与 numpy.indices() 的行为相匹配。以前,非整数维度会被静默转换为整数。

    • 如果 jax.numpy.diag()k 参数不是整数类型,则现在会引发 TypeError,这与 numpy.diag() 的行为相匹配。以前,非整数 k 会被静默转换为整数。

    • 添加了 jax.random.orthogonal()

  • 弃用

    • jax.test_util 中许多函数和对象现已弃用,导入时会发出警告。这包括 cases_from_listcheck_closecheck_eqdevice_under_testformat_shape_dtype_stringrand_uniformskip_on_deviceswith_configxla_bridge_default_tolerance (#10389)。这些以及之前已弃用的 JaxTestCaseJaxTestLoaderBufferDonationTestCase 将在未来的 JAX 版本中移除。大多数这些实用程序可以用对标准 Python 和 NumPy 测试实用程序的调用来替代,这些实用程序可以在例如 unittestabsl.testingnumpy.testing 等中找到。JAX 特定的功能(如设备检查)可以通过使用公共 API(如 jax.devices())来替代。许多已弃用的实用程序仍将存在于 jax._src.test_util 中,但这些不是公共 API,因此可能会在未来的版本中更改或删除,恕不另行通知。

jax 0.3.7 (2022 年 4 月 15 日)#

jaxlib 0.3.7 (2022 年 4 月 15 日)#

  • 更改

    • Linux wheels 现在按照 manylinux2014 标准构建,而不是 manylinux2010

jax 0.3.6 (2022 年 4 月 12 日)#

  • GitHub 提交.

  • 更改

    • 升级了 libtpu wheel 到修复了初始化 TPU pod 时挂起的版本。修复了 #10218

  • 弃用

    • jax.experimental.loops 正在被弃用。有关替代 API,请参阅 #10278

jax 0.3.5 (2022 年 4 月 7 日)#

jaxlib 0.3.5 (2022 年 4 月 7 日)#

  • 错误修复

    • 修复了双精度复数到实数 IRFFT 在 GPU 上会改变其输入缓冲区的错误 (#9946)。

    • 修复了复数散布的错误常量折叠 (#10159)

jax 0.3.4 (2022 年 3 月 18 日)#

jax 0.3.3 (2022 年 3 月 17 日)#

jax 0.3.2 (2022 年 3 月 16 日)#

  • GitHub 提交.

  • 更改

    • 已移除在 0.2.22 中弃用的函数 jax.ops.index_updatejax.ops.index_add。请改用 JAX 数组上的 .at 属性,例如 x.at[idx].set(y)

    • jax.experimental.ann.approx_*_k 移到 jax.lax 中。这些函数是 jax.lax.top_k 的优化替代方案。

    • jax.numpy.broadcast_arrays()jax.numpy.broadcast_to() 现在要求输入为标量或类似数组,如果传递列表则会失败 (部分属于 #7737)。

    • 标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VM 一起使用。

    • pjit 现在可以在 CPU 上运行(除了之前支持的 TPU 和 GPU)。

jaxlib 0.3.2 (2022 年 3 月 16 日)#

  • 更改

    • XlaComputation.as_hlo_text() 现在支持通过传递布尔标志 print_large_constants=True 来打印大型常量。

  • 弃用

    • JAX 数组上的 .block_host_until_ready() 方法已被弃用。请改用 .block_until_ready()

jax 0.3.1 (2022 年 2 月 18 日)#

jax 0.3.0 (2022 年 2 月 10 日)#

jaxlib 0.3.0 (2022 年 2 月 10 日)#

  • 更改

    • 现在需要 Bazel 5.0.0 来构建 jaxlib。

    • jaxlib 版本已提升至 0.3.0。请参阅设计文档了解详细说明。

jax 0.2.28 (2022 年 2 月 1 日)#

  • GitHub 提交.

    • 如果没有传递 dialect=jax.jit(f).lower(...).compiler_ir() 现在默认使用 MHLO 方言。

    • jax.jit(f).lower(...).compiler_ir(dialect='mhlo') 现在返回 MLIR ir.Module 对象,而不是其字符串表示形式。

jaxlib 0.1.76 (2022 年 1 月 27 日)#

  • 新功能

    • 包含适用于 NVidia 计算能力 8.0 GPU(例如 A100)的预编译 SASS。移除计算能力 6.1 的预编译 SASS,以避免增加计算能力的数量:计算能力为 6.1 的 GPU 可以使用 6.0 SASS。

    • 在 jaxlib 0.1.76 中,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。

  • 重大更改

    • 已根据弃用策略,放弃对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。

  • 错误修复

    • 修复了由不同路径构建的看似相同的 pytreedef 对象无法比较为相等的问题 (#9066)。

    • JAX jit 缓存要求两个静态参数具有相同的类型才能命中缓存 (#9311)。

jax 0.2.27 (2022 年 1 月 18 日)#

  • GitHub 提交.

  • 重大更改

    • 已根据弃用策略,放弃对 NumPy 1.18 的支持。请升级到受支持的 NumPy 版本。

    • host_callback 原语已简化,不再对 hcb.id_tap 和 id_print 进行特殊的自动微分处理。从现在开始,只提取原始值。可以通过设置 JAX_HOST_CALLBACK_AD_TRANSFORMS 环境变量或 --jax_host_callback_ad_transforms 标志来获得旧的行为(在有限的时间内)。此外,还添加了关于如何使用 JAX 自定义 AD API 实现旧行为的文档 (#8678)。

    • 排序现在匹配 NumPy 对于 0.0NaN 的行为,无论位表示形式如何。特别是,0.0-0.0 现在被视为等价,而之前 -0.0 被视为小于 0.0。此外,所有 NaN 表示形式现在被视为等价,并排序到数组的末尾。之前,负的 NaN 值被排序到数组的前面,并且具有不同内部位表示形式的 NaN 值不被视为等价,而是根据这些位模式进行排序 (#9178)。

    • jax.numpy.unique() 现在以与 NumPy 版本 1.21 及更新版本中的 np.unique 相同的方式处理 NaN 值:最多只有一个 NaN 值会出现在去重后的输出中 (#9184)。

  • 错误修复

    • host_callback 现在支持 ad_checkpoint.checkpoint (#8907)。

  • 新功能

    • 添加 jax.block_until_ready ({jax-issue}`#8941)

    • 添加了一个新的调试标志/环境变量 JAX_DUMP_IR_TO=/path。如果设置,JAX 会将它为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件中。

    • jax.ensure_compile_time_eval 添加到公共 API 中 (#7987)。

    • jax2tf 现在支持标志 jax2tf_associative_scan_reductions,以更改关联归约(例如 jnp.cumsum)的降级,使其在 CPU 和 GPU 上像 JAX 一样运行(使用关联扫描)。有关更多详细信息,请参阅 jax2tf README (#9189)。

jaxlib 0.1.75 (2021 年 12 月 8 日)#

  • 新功能

    • 支持 Python 3.10。

jax 0.2.26 (2021 年 12 月 8 日)#

  • GitHub 提交.

  • 错误修复

    • 超出范围的 jax.ops.segment_sum 索引现在将按照文档中的说明,使用 FILL_OR_DROP 语义进行处理。这主要影响反向模式导数,其中与超出范围的索引对应的梯度现在将返回为 0。( #8634)。

    • jax2tf 将强制转换后的代码对 jax.jit 下的代码片段使用 XLA,例如大多数 jax.numpy 函数 (#7839)。

jaxlib 0.1.74 (2021 年 11 月 17 日)#

  • 启用 GPU 之间的对等复制。以前,GPU 复制是通过主机反弹的,这通常较慢。

  • 为 JAX 添加了实验性的 MLIR Python 绑定。

jax 0.2.25 (2021 年 11 月 10 日)#

  • GitHub 提交.

  • 新功能

    • (实验性) jax.distributed.initialize 公开了多主机 GPU 后端。

    • jax.random.permutation 支持新的 independent 关键字参数 (#8430)

  • 重大更改

    • jax.experimental.stax 移动到 jax.example_libraries.stax

    • jax.experimental.optimizers 移动到 jax.example_libraries.optimizers

  • 新功能

    • 添加了 jax.lax.linalg.qdwh

jax 0.2.24 (2021 年 10 月 19 日)#

  • GitHub 提交.

  • 新功能

    • jax.random.choicejax.random.permutation 现在支持多维数组和一个可选的 axis 参数 (#8158)

  • 重大更改

    • jax.numpy.takejax.numpy.take_along_axis 现在需要类似数组的输入 (请参阅 #7737)

jaxlib 0.1.73 (2021 年 10 月 18 日)#

  • 现在,jaxlib GPU cuda11 wheels 支持多个 cuDNN 版本。

    • cuDNN 8.2 或更高版本。我们建议您在 cuDNN 安装足够新的情况下使用 cuDNN 8.2 wheel,因为它支持其他功能。

    • cuDNN 8.0.5 或更高版本。

  • 重大更改

    • GPU jaxlib 的安装命令如下

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      

jax 0.2.22 (2021 年 10 月 12 日)#

  • GitHub 提交.

  • 重大更改

    • jax.pmap 的静态参数现在必须是可哈希的。

      不可哈希的静态参数长期以来都不允许在 jax.jit 上使用,但仍然允许在 jax.pmap 上使用;jax.pmap 使用对象标识比较不可哈希的静态参数。

      此行为是一个潜在的陷阱,因为使用对象标识比较参数会导致每次对象标识更改时都进行重新编译。相反,我们现在禁止不可哈希的参数:如果 jax.pmap 的用户希望通过对象标识比较静态参数,他们可以在其对象上定义 __hash____eq__ 方法来实现此目的,或者将他们的对象包装在具有对象标识语义的这些操作的对象中。另一种选择是使用 functools.partial 将不可哈希的静态参数封装到函数对象中。

    • jax.util.partial 是一次意外导出,现在已删除。请改用 Python 标准库中的 functools.partial

  • 弃用

    • 函数 jax.ops.index_updatejax.ops.index_add 等已弃用,并将在未来的 JAX 版本中删除。请改用 JAX 数组上的 .at 属性,例如,x.at[idx].set(y)。现在,这些函数会产生 DeprecationWarning

  • 新功能

    • 使用 jaxlib 0.1.72 或更高版本时,用于改进 pmap 分派时间的优化 C++ 代码路径现在是默认设置。可以使用 --experimental_cpp_pmap 标志(或 JAX_CPP_PMAP 环境变量)禁用此功能。

    • jax.numpy.unique 现在支持可选的 fill_value 参数 (#8121)

jaxlib 0.1.72 (2021 年 10 月 12 日)#

  • 重大更改

    • 已放弃对 CUDA 10.2 和 CUDA 10.1 的支持。Jaxlib 现在支持 CUDA 11.1+。

  • 错误修复

    • 修复了 https://github.com/jax-ml/jax/issues/7461,该问题由于 XLA 编译器内部错误的缓冲区别名而导致所有平台上的输出错误。

jax 0.2.21 (2021 年 9 月 23 日)#

  • GitHub 提交.

  • 重大更改

    • jax.api 已删除。jax.api.* 中提供的函数是 jax.* 中函数的别名;请改用 jax.* 中的函数。

    • jax.partialjax.lax.partial 是意外导出,现在已删除。请改用 Python 标准库中的 functools.partial

    • 布尔标量索引现在会引发 TypeError 错误;之前此操作会静默地返回错误的结果 (#7925)。

    • 现在,许多 jax.numpy 函数都要求输入为类数组对象,如果传入列表则会报错 (#7747 #7802 #7907)。有关此更改背后的基本原理,请参阅 #7737

    • 当在诸如 jax.jit 之类的转换内部时,jax.numpy.array 总是会将它产生的数组暂存到跟踪计算中。之前,即使在 jax.jit 修饰符下,jax.numpy.array 有时也会生成设备上的数组。此更改可能会破坏使用 JAX 数组执行必须静态已知的形状或索引计算的代码;解决方法是改用经典的 NumPy 数组执行此类计算。

    • jnp.ndarray 现在是 JAX 数组的真正基类。特别是,这意味着对于标准的 NumPy 数组 xisinstance(x, jnp.ndarray) 现在将返回 False (#7927)。

  • 新功能

jax 0.2.20 (2021 年 9 月 2 日)#

  • GitHub 提交.

  • 重大更改

    • jnp.poly* 函数现在要求输入为类数组对象 (#7732)

    • jnp.unique 和其他类似集合的操作现在要求输入为类数组对象 (#7662)

jaxlib 0.1.71 (2021 年 9 月 1 日)#

  • 重大更改

    • 已删除对 CUDA 11.0 和 CUDA 10.1 的支持。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。

jax 0.2.19 (2021 年 8 月 12 日)#

  • GitHub 提交.

  • 重大更改

    • 根据弃用策略,已删除对 NumPy 1.17 的支持。请升级到受支持的 NumPy 版本。

    • 已在 JAX 数组上的一些运算符的实现周围添加了 jit 修饰符。这加快了诸如 + 之类的常用运算符的调度时间。

      此更改对大多数用户来说应该是透明的。但是,有一个已知的行为更改,即当将大型整数常量直接传递给 JAX 运算符时(例如,x + 2**40),现在可能会产生错误。解决方法是将常量强制转换为显式类型(例如,np.float64(2**40))。

  • 新功能

    • 改进了 jax2tf 中对需要在数组计算中使用维度大小的操作(例如,jnp.mean)的形状多态性的支持。( #7317 )

  • 错误修复

    • 修复了上一个版本中泄漏的一些跟踪错误 (#7613)

jaxlib 0.1.70 (2021 年 8 月 9 日)#

  • 重大更改

    • 根据弃用策略,已删除对 Python 3.6 的支持。请升级到受支持的 Python 版本。

    • 根据弃用策略,已删除对 NumPy 1.17 的支持。请升级到受支持的 NumPy 版本。

    • host_callback 机制现在为每个本地设备使用一个线程来调用 Python 回调。之前,所有设备都只有一个线程。这意味着现在回调可能会交错调用。与一个设备对应的回调仍将按顺序调用。

jax 0.2.18 (2021 年 7 月 21 日)#

  • GitHub 提交.

  • 重大更改

    • 根据弃用策略,已删除对 Python 3.6 的支持。请升级到受支持的 Python 版本。

    • 最低 jaxlib 版本现在是 0.1.69。

    • 已删除 jax.dlpack.from_dlpack()backend 参数。

  • 新功能

  • 错误修复

    • 收紧了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会与无效的 axis 值或空的缩减维度一起使用。( #7196 )

jaxlib 0.1.69 (2021 年 7 月 9 日)#

  • 修复了 TFRT CPU 后端中的错误,该错误导致结果不正确。

jax 0.2.17 (2021 年 7 月 9 日)#

  • GitHub 提交.

  • 错误修复

    • 默认使用较旧的 “stream_executor” CPU 运行时,以便 jaxlib <= 0.1.68 可以解决 #7229 问题,该问题由于并发问题而导致 CPU 上的错误输出。

  • 新功能

jax 0.2.16 (2021 年 6 月 23 日)#

jax 0.2.15 (2021 年 6 月 23 日)#

  • GitHub 提交.

  • 新功能

    • #7042 启用了 TFRT CPU 后端,显著提高了 CPU 上的调度性能。

    • jax2tf.convert() 支持布尔值的比较运算和 min/max 操作 (#6956)。

    • 新的 SciPy 函数 jax.scipy.special.lpmn_values()

  • 重大更改

  • 错误修复

    • 修复了阻止从 JAX 到 TF 再返回的往返错误的错误:jax2tf.call_tf(jax2tf.convert) ( #6947 )。

jaxlib 0.1.68 (2021 年 6 月 23 日)#

  • 错误修复

    • 修复了 TFRT CPU 后端中的错误,该错误在将 TPU 缓冲区传输到 CPU 时会产生 NaN。

jax 0.2.14 (2021 年 6 月 10 日)#

  • GitHub 提交.

  • 新功能

    • jax2tf.convert() 现在支持 pjitsharded_jit

    • 新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯。

    • 在足够新的 IPython 版本中,默认情况下启用使用 __tracebackhide__ 的新回溯过滤模式。

    • 即使在算术运算中使用未知维度,jax2tf.convert() 也支持形状多态性,例如,jnp.reshape(-1) ( #6827 )。

    • jax2tf.convert() 在 TF 操作中生成带有位置信息的自定义属性。jax2tf 之后 XLA 生成的代码具有与 JAX/XLA 相同的位置信息。

    • 新的 SciPy 函数 jax.scipy.special.lpmn()

  • 错误修复

    • jax2tf.convert() 现在确保它使用与 JAX 相同的 Python 标量类型规则,并选择 32 位与 64 位计算规则 ( #6883 )。

    • jax2tf.convert() 现在正确地限定 enable_xla 转换参数的范围,使其仅在即时转换期间应用 ( #6720 )。

    • jax2tf.convert() 现在使用 XlaDot TensorFlow 操作转换 lax.dot_general,以便更好地保证与 JAX 数值精度的一致性 ( #6717 )。

    • jax2tf.convert() 现在支持复数的比较运算和 min/max 操作 ( #6892 )。

jaxlib 0.1.67 (2021 年 5 月 17 日)#

jaxlib 0.1.66 (2021 年 5 月 11 日)#

  • 新功能

    • 现在在所有 CUDA 11 版本(11.1 或更高版本)上都支持 CUDA 11.1 wheels。

      NVidia 现在承诺 CUDA 次要版本(从 CUDA 11.1 开始)之间的兼容性。这意味着 JAX 可以发布一个与 CUDA 11.2 和 11.3 兼容的单个 CUDA 11.1 wheel。

      CUDA 11.2(或更高版本)不再有单独的 jaxlib 版本;这些版本请使用 CUDA 11.1 wheel (cuda111)。

    • Jaxlib 现在在 CUDA wheels 中捆绑了 libdevice.10.bc。应该不再需要让 JAX 指向 CUDA 安装目录来查找此文件。

    • jit() 实现添加了对静态关键字参数的自动支持。

    • 添加了对预转换异常跟踪的支持。

    • 初步支持从 jit() 转换的计算中修剪未使用的参数。修剪仍在进行中。

    • 改进了 PyTreeDef 对象的字符串表示。

    • 添加了对 XLA 的可变参数 ReduceWindow 的支持。

  • 错误修复

    • 修复了当大量参数传递给计算时,远程云 TPU 支持中的一个错误。

    • 修复了一个导致 JAX 垃圾回收未被 jit() 转换的函数触发的错误。

jax 0.2.13 (2021 年 5 月 3 日)#

  • GitHub 提交.

  • 新功能

    • 当与 jaxlib 0.1.66 结合使用时,jax.jit() 现在支持静态关键字参数。添加了一个新的 static_argnames 选项,用于将关键字参数指定为静态。

    • jax.nonzero() 有一个新的可选的 size 参数,允许它在 jit 中使用 (#6501)

    • jax.numpy.unique() 现在支持 axis 参数 (#6532)。

    • jax.experimental.host_callback.call() 现在支持 pjit.pjit (#6569)。

    • 添加了 jax.scipy.linalg.eigh_tridiagonal(),用于计算三对角矩阵的特征值。目前仅支持特征值。

    • 异常中过滤的和未过滤的堆栈跟踪的顺序已更改。从 JAX 转换的代码抛出的异常的追溯现在是经过过滤的,一个包含原始跟踪的 UnfilteredStackTrace 异常作为过滤的异常的 __cause__。过滤的堆栈跟踪现在也适用于 Python 3.6。

    • 如果被反向模式自动微分转换的代码抛出异常,JAX 现在会尝试将一个 JaxStackTraceBeforeTransformation 对象作为异常的 __cause__ 附加,该对象包含在前向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。

  • 重大更改

    • 以下函数名称已更改。仍然存在别名,因此这不应该破坏现有代码,但别名最终将被删除,因此请更改您的代码。

    • 类似地,local_devices() 的参数已从 host_id 重命名为 process_index

    • 除了函数之外,传递给 jax.jit() 的参数现在被标记为仅限关键字。此更改是为了防止在向 jit 添加参数时意外中断。

  • 错误修复

    • 当存在整数输入函数的梯度时,jax2tf.convert() 现在可以工作了 (#6360)。

    • 修复了与捕获的 tf.Variable 一起使用时 jax2tf.call_tf() 中的断言失败问题 (#6572)。

jaxlib 0.1.65 (2021 年 4 月 7 日)#

jax 0.2.12 (2021 年 4 月 1 日)#

  • GitHub 提交.

  • 新功能

  • 重大更改

    • 最低 jaxlib 版本现在是 0.1.64。

    • 一些分析器 API 名称已更改。仍然存在别名,因此这不应该破坏现有代码,但别名最终将被删除,因此请更改您的代码。

    • Omnistaging 不能再被禁用。有关更多信息,请参阅 omnistaging

    • 大于最大 int64 值的 Python 整数现在在所有情况下都会导致溢出,而不是在某些情况下静默转换为 uint64 (#6047)。

    • 在 X64 模式之外,超出 int32 可表示范围的 Python 整数现在会导致 OverflowError,而不是静默截断其值。

  • 错误修复

    • host_callback 现在支持参数和结果中的空数组 (#6262)。

    • jax.random.randint() 剪辑而不是环绕超出范围的限制,并且现在可以在指定 dtype 的完整范围内生成整数 (#5868)

jax 0.2.11 (2021 年 3 月 23 日)#

  • GitHub 提交.

  • 新功能

    • #6112 添加了上下文管理器:jax.enable_checksjax.check_tracer_leaksjax.debug_nansjax.debug_infsjax.log_compiles

    • #6085 添加了 jnp.delete

  • 错误修复

    • #6136jax.flatten_util.ravel_pytree 推广为处理整数 dtype。

    • #6129 修复了处理某些常量(如 enum.IntEnums)的错误

    • #6145 修复了不完整 beta 函数的批处理问题

    • #6014 修复了跟踪期间的 H2D 传输

    • #6165 避免了将一些大型 Python 整数转换为浮点数时出现 OverflowErrors

  • 重大更改

    • 最低 jaxlib 版本现在是 0.1.62。

jaxlib 0.1.64 (2021 年 3 月 18 日)#

jaxlib 0.1.63 (2021 年 3 月 17 日)#

jax 0.2.10 (2021 年 3 月 5 日)#

  • GitHub 提交.

  • 新功能

    • jax.scipy.stats.chi2() 现在可以作为具有 logpdf 和 pdf 方法的分布使用。

    • jax.scipy.stats.betabinom() 现在可以作为具有 logpmf 和 pmf 方法的分布使用。

    • 添加了 jax.experimental.jax2tf.call_tf() 以从 JAX 调用 TensorFlow 函数 (#5627) 和 README)。

    • 扩展了 lax.pad 的批处理规则,以支持填充值的批处理。

  • 错误修复

  • 重大更改

    • 调整了 JAX 的提升规则,以使提升更加一致,并且与 JIT 不变。特别是,二进制操作现在可以在适当的情况下产生弱类型的值。更改的主要用户可见效果是,某些操作产生的结果精度与以前不同;例如,表达式 jnp.bfloat16(1) + 0.1 * jnp.arange(10) 之前返回一个 float64 数组,现在返回一个 bfloat16 数组。JAX 的类型提升行为在 类型提升语义 中进行了描述。

    • jax.numpy.linspace() 现在计算整数值的向下取整,即向 -inf 方向舍入,而不是向 0 舍入。此更改是为了与 NumPy 1.20.0 匹配。

    • jax.numpy.i0() 不再接受复数。之前,该函数计算复数参数的绝对值。此更改是为了与 NumPy 1.20.0 的语义匹配。

    • 多个 jax.numpy 函数不再接受元组或列表来代替数组参数:jax.numpy.pad()、:funcjax.numpy.raveljax.numpy.repeat()jax.numpy.reshape()。一般来说,jax.numpy 函数应使用标量或数组参数。

jaxlib 0.1.62 (2021年3月9日)#

  • 新功能

    • 默认情况下,现在构建 jaxlib wheels 以要求 x86-64 机器上的 AVX 指令。如果要在不支持 AVX 的机器上使用 JAX,可以使用 --target_cpu_features 标志来使用源代码构建 jaxlib。 --target_cpu_features 也取代了 --enable_march_native

jaxlib 0.1.61 (2021年2月12日)#

jaxlib 0.1.60 (2021年2月3日)#

  • 错误修复

    • 修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏。内存泄漏存在于 jaxlib 版本 0.1.58 和 0.1.59 中。

    • boolint8uint8 现在被认为是安全地转换为 bfloat16 NumPy 扩展类型。

jax 0.2.9 (2021年1月26日)#

jaxlib 0.1.59 (2021年1月15日)#

jax 0.2.8 (2021年1月12日)#

  • GitHub 提交.

  • 新功能

    • 添加 jax.closure_convert() 以用于高阶自定义导数函数。 (#5244)

    • 添加 jax.experimental.host_callback.call() 以在主机上调用自定义 Python 函数并将结果返回到设备计算。 (#5243)

  • 错误修复

    • jax.numpy.arccosh 现在为复数输入返回与 numpy.arccosh 相同的分支 (#5156)

    • host_callback.id_tap 现在也适用于 jax.pmap。 对于 id_tapid_print 存在一个可选参数,用于请求将从中提取值的设备作为关键字参数传递给 tap 函数 (#5182)。

  • 重大更改

    • jax.numpy.pad 现在接受关键字参数。位置参数 constant_values 已被删除。此外,传递不支持的关键字参数会引发错误。

    • jax.experimental.host_callback.id_tap() 的更改 (#5243)

      • 删除了对 jax.experimental.host_callback.id_tap()kwargs 的支持。(此支持已弃用几个月。)

      • 更改了 jax.experimental.host_callback.id_print() 的元组打印,以使用“(”代替“[”。

      • 更改了在 JVP 存在时 jax.experimental.host_callback.id_print() 以打印原始值和切线对。以前,原始值和切线有两个单独的打印操作。

      • host_callback.outfeed_receiver 已被删除(它不是必需的,并且几个月前已弃用)。

  • 新功能

    • 用于调试 inf 的新标志,类似于 NaN 的标志 (#5224)。

jax 0.2.7 (2020年12月4日)#

  • GitHub 提交.

  • 新功能

    • 添加 jax.device_put_replicated

    • jax.experimental.sharded_jit 添加多主机支持

    • 添加对由 jax.numpy.linalg.eig 计算的特征值进行微分的支持

    • 添加对在 Windows 平台上构建的支持

    • jax.pmap 中添加对通用 in_axes 和 out_axes 的支持

    • jax.numpy.linalg.slogdet 添加复数支持

  • 错误修复

    • 修复了 jax.numpy.sinc 在零处的更高阶导数

    • 修复了转置规则中围绕符号零的一些难以命中的错误

  • 重大更改

    • jax.experimental.optix 已被删除,以支持独立的 optax Python 包。

    • 现在使用非元组序列对 JAX 数组进行索引会引发 TypeError。自 v1.16 以来,这种类型的索引在 Numpy 中已弃用,并且自 v0.2.4 以来在 JAX 中已弃用。请参阅 #4564

jax 0.2.6 (2020年11月18日)#

  • GitHub 提交.

  • 新功能

    • 为 jax.experimental.jax2tf 转换器添加对形状多态跟踪的支持。请参阅 README.md

  • 重大更改清理

    • 对于 jax.jit 和 xla_computation 的不可哈希静态参数,引发错误。请参阅 cb48f42

    • 改进类型提升行为的一致性 (#4744)

      • 将复数 Python 标量添加到 JAX 浮点数时,会尊重 JAX 浮点数的精度。例如,jnp.float32(1) + 1j 现在返回 complex64,而以前返回 complex128

      • 涉及 uint64、有符号 int 和第三种类型的 3 个或更多项的类型提升结果现在与参数的顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) 都返回 float16,而以前第一个返回 float64,第二个返回 float16

    • (未记录的)jax.lax_linalg 线性代数模块的内容现在公开作为 jax.lax.linalg 公开。

    • 现在 jax.random.PRNGKey 在 JIT 编译内外产生相同的结果 (#4877)。这需要在一些特定情况下更改给定种子的结果

      • 使用 jax_enable_x64=False,作为 Python 整数传递的负种子现在在 JIT 模式外返回不同的结果。例如,jax.random.PRNGKey(-1) 之前返回 [4294967295, 4294967295],现在返回 [0, 4294967295]。这与 JIT 中的行为匹配。

      • 超出 JIT 外的 int64 可表示范围的种子现在会导致 OverflowError 而不是 TypeError。这与 JIT 中的行为匹配。

      要恢复以前为 JIT 外的 jax_enable_x64=False 的负整数返回的密钥,可以使用

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • 现在,当尝试访问已删除的 DeviceArray 的值时,DeviceArray 会引发 RuntimeError 而不是 ValueError

jaxlib 0.1.58 (约 2021 年 1 月 12 日)#

  • 修复了一个错误,该错误意味着 JAX 有时会返回平台特定的类型(例如,np.cint)而不是标准类型(例如,np.int32)。 (#4903)

  • 修复了在恒定折叠某些 int16 操作时发生的崩溃。 (#4971)

  • pytree.flatten() 添加了 is_leaf 谓词。

jaxlib 0.1.57 (2020年11月12日)#

  • 修复了 GPU wheels 中的 manylinux2010 合规性问题。

  • 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。

  • 修复了 bfloat16 值的哈希未正确初始化并可能更改的错误 (#4651)。

  • 添加了在将数组传递给 DLPack 时保留所有权的支持 (#4636)。

  • 修复了批处理三角求解的大小大于 128 但不是 128 的倍数的错误。

  • 修复了在多个 GPU 上执行并发 FFT 时发生的错误 (#3518)。

  • 修复了分析器中缺少工具的错误 (#4427)。

  • 放弃了对 CUDA 10.0 的支持。

JAX 0.2.5 (2020 年 10 月 27 日)#

JAX 0.2.4 (2020 年 10 月 19 日)#

  • GitHub 提交.

  • 改进

    • 为 jax.experimental.host_callback 添加 remat 的支持。请参阅 #4608

  • 弃用

    • 使用非元组序列进行索引现在已被弃用,遵循 Numpy 中类似的弃用。在未来的版本中,这将导致 TypeError。请参阅 #4564

jaxlib 0.1.56 (2020 年 10 月 14 日)#

JAX 0.2.3 (2020 年 10 月 14 日)#

  • GitHub 提交.

  • 如此迅速地发布另一个版本的原因是,我们需要暂时回滚一个新的 jit 快速路径,同时我们正在研究性能下降的问题

JAX 0.2.2 (2020 年 10 月 13 日)#

JAX 0.2.1 (2020 年 10 月 6 日)#

  • GitHub 提交.

  • 改进

    • 作为全阶段化的好处,即使 jax.experimental.host_callback.id_print()/ jax.experimental.host_callback.id_tap() 的结果未在计算中使用,host_callback 函数也会(按程序顺序)执行。

JAX (0.2.0) (2020 年 9 月 23 日)#

JAX (0.1.77) (2020 年 9 月 15 日)#

  • 重大更改

    • 用于 jax.experimental.host_callback.id_tap() 的新的简化接口 (#4101)

jaxlib 0.1.55 (2020 年 9 月 8 日)#

  • 更新 XLA

    • 修复 DLPackManagedTensorToBuffer 中的错误 (#4196)

JAX 0.1.76 (2020 年 9 月 8 日)#

JAX 0.1.75 (2020 年 7 月 30 日)#

  • GitHub 提交.

  • 错误修复

    • 使 jnp.abs() 适用于无符号输入 (#3914)

  • 改进

    • “全阶段化”行为在标志后添加,默认禁用 (#3370)

JAX 0.1.74 (2020 年 7 月 29 日)#

  • GitHub 提交.

  • 新功能

    • BFGS (#3101)

    • TPU 支持半精度算术 (#3878)

  • 错误修复

    • 防止一些意外的 dtype 警告 (#3874)

    • 修复自定义导数中的多线程错误 (#3845, #3869)

  • 改进

    • 更快的 searchsorted 实现 (#3873)

    • jax.numpy 排序算法的更好的测试覆盖率 (#3836)

jaxlib 0.1.52 (2020 年 7 月 22 日)#

  • 更新 XLA。

JAX 0.1.73 (2020 年 7 月 22 日)#

  • GitHub 提交.

  • 现在最低 jaxlib 版本是 0.1.51。

  • 新功能

    • jax.image.resize。 (#3703)

    • hfft 和 ihfft (#3664)

    • jax.numpy.intersect1d (#3726)

    • jax.numpy.lexsort (#3812)

    • lax.scanscan 原语在降低到 XLA 时支持用于循环展开的 unroll 参数(#3738)。

  • 错误修复

    • 修复缩减重复轴错误 (#3618)

    • 修复大小为 0 的输入维度 lax.pad 的形状规则 (#3608)

    • 使 psum 转置处理零余切 (#3653)

    • 修复在对大小为 0 的轴进行 reduce-prod 的 JVP 时出现的形状错误 (#3729)

    • 支持通过 jax.lax.all_to_all 进行微分 (#3733)

    • 解决 jax.scipy.special.zeta 中的 nan 问题 (#3777)

  • 改进

    • jax2tf 的许多改进

    • 使用单次可变缩减重新实现 argmin/argmax。 (#3611)

    • 默认启用 XLA SPMD 分区。 (#3151)

    • 添加对 0d 转置卷积的支持 (#3643)

    • 使 LU 梯度适用于低秩矩阵 (#3610)

    • 在 jet 中支持 multiple_results 和自定义 JVP (#3657)

    • 概括 reduce-window 填充以支持 (lo, hi) 对。 (#3728)

    • 在 CPU 和 GPU 上实现复数卷积。 (#3735)

    • 使 jnp.take 适用于空数组的空切片。 (#3751)

    • 放宽 dot_general 的维度排序规则。 (#3778)

    • 启用 GPU 的缓冲区捐赠。 (#3800)

    • 为 reduce window 操作添加对基本扩张和窗口扩张的支持… (#3803)

jaxlib 0.1.51 (2020 年 7 月 2 日)#

  • 更新 XLA。

  • 为 host_callback 添加新的运行时支持。

JAX 0.1.72 (2020 年 6 月 28 日)#

  • GitHub 提交.

  • 错误修复

    • 修复了在之前的版本中引入的 odeint 错误,请参阅 #3587

JAX 0.1.71 (2020 年 6 月 25 日)#

  • GitHub 提交.

  • 现在最低 jaxlib 版本是 0.1.48。

  • 错误修复

    • 允许 jax.experimental.ode.odeint 动力学函数关闭相对于我们正在微分的值 #3562

jaxlib 0.1.50 (2020 年 6 月 25 日)#

  • 添加对 CUDA 11.0 的支持。

  • 删除对 CUDA 9.2 的支持(我们仅维护对最后四个 CUDA 版本的支持。)

  • 更新 XLA。

jaxlib 0.1.49 (2020 年 6 月 19 日)#

jaxlib 0.1.48 (2020 年 6 月 12 日)#

  • 新功能

    • 添加对快速回溯收集的支持。

    • 添加对设备上堆分析的初步支持。

    • bfloat16 类型实现 np.nextafter

    • CPU 和 GPU 上的 FFT 的 Complex128 支持。

  • 错误修复

    • 改进了 GPU 上 float64 tanh 的精度。

    • GPU 上的 float64 scatter 快得多。

    • CPU 上的复数矩阵乘法应该快得多。

    • CPU 上的稳定排序现在应该真正稳定了。

    • CPU 后端中的并发错误修复。

JAX 0.1.70 (2020 年 6 月 8 日)#

  • GitHub 提交.

  • 新功能

    • lax.switch 引入了具有多个分支的索引条件,以及 cond 原语的泛化 #3318

JAX 0.1.69 (2020 年 6 月 3 日)#

JAX 0.1.68 (2020 年 5 月 21 日)#

  • GitHub 提交.

  • 新功能

    • lax.cond() 支持单操作数形式,该形式作为两个分支的参数 #2993

  • 值得注意的更改

    • jax.experimental.host_callback.id_tap() 原语的 transforms 关键字的格式已更改 #3132

JAX 0.1.67 (2020 年 5 月 12 日)#

  • GitHub 提交.

  • 新功能

    • 使用 axis_index_groups 支持在 pmapped 轴的子集上进行缩减 #2382

    • 实验性支持从已编译的代码中打印和调用主机端 Python 函数。请参阅 id_print 和 id_tap (#3006)。

  • 值得注意的更改

    • jax.numpy 导出的名称的可见性已收紧。这可能会破坏以前意外使用导出名称的代码。

jaxlib 0.1.47 (2020 年 5 月 8 日)#

  • 修复 outfeed 的崩溃。

JAX 0.1.66 (2020 年 5 月 5 日)#

jaxlib 0.1.46 (2020 年 5 月 5 日)#

  • 修复 Mac OS X 上线性代数函数的崩溃 (#432)。

  • 修复了在操作系统或虚拟机禁用 AVX512 指令时,由于使用 AVX512 指令而导致的非法指令崩溃 (#2906)。

JAX 0.1.65 (2020 年 4 月 30 日)#

  • GitHub 提交.

  • 新功能

    • 奇异矩阵的行列式的微分 #2809

  • 错误修复

    • 修复具有时间相关动力学的 ODE 的时间 odeint() 微分 #2817,还添加 ODE CI 测试。

    • 修复 lax_linalg.qr() 微分 #2867

jaxlib 0.1.45 (2020 年 4 月 21 日)#

  • 修复段错误:#2755

  • 将 Sort HLO 上的 is_stable 选项管道传输到 Python。

JAX 0.1.64 (2020 年 4 月 21 日)#

jaxlib 0.1.44 (2020 年 4 月 16 日)#

  • 修复了一个 bug,即如果存在多个不同型号的 GPU,JAX 只会编译适用于第一个 GPU 的程序。

  • 修复了 batch_group_count 卷积的 bug。

  • 为更多 GPU 版本添加了预编译的 SASS,以避免启动 PTX 编译挂起。

jax 0.1.63 (2020 年 4 月 12 日)#

  • GitHub 提交.

  • #2026 添加了 jax.custom_jvpjax.custom_vjp,请参阅 教程笔记本。 弃用了 jax.custom_transforms 并从文档中删除(尽管它仍然有效)。

  • 添加 scipy.sparse.linalg.cg #2566

  • 更改了 Tracers 的打印方式,以显示更多有用的调试信息 #2591

  • 使 jax.numpy.isclose 正确处理 naninf #2501

  • jax.experimental.jet 添加了几个新规则 #2537

  • 修复了未提供 scale/center 时的 jax.experimental.stax.BatchNorm

  • 修复了 jax.numpy.einsum 中一些缺失的广播情况 #2512

  • 用并行前缀扫描实现了 jax.numpy.cumsumjax.numpy.cumprod #2596,并使 reduce_prod 可微分到任意阶 #2597

  • conv_general_dilated 添加 batch_group_count #2635

  • test_util.check_grads 添加了文档字符串 #2656

  • 添加 callback_transform #2665

  • 实现了 rollaxisconvolve/correlate 1d & 2d、copysigntruncroots 以及 quantile/percentile 插值选项。

jaxlib 0.1.43 (2020 年 3 月 31 日)#

  • 修复了 GPU 上 Resnet-50 的性能回归问题。

jax 0.1.62 (2020 年 3 月 21 日)#

  • GitHub 提交.

  • JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更高版本。

  • 删除了内部函数 lax._safe_mul,该函数实现了 0. * nan == 0. 的约定。此更改意味着某些程序在微分时会产生 nan,而之前会产生正确的值,尽管它确保为其他程序生成 nan 而不是静默的错误结果。有关详细信息,请参阅 #2447 和 #1052。

  • 添加了 all_gather 并行便利函数。

  • 核心代码中添加了更多类型注释。

jaxlib 0.1.42 (2020 年 3 月 19 日)#

  • 由于 API 不兼容,jaxlib 0.1.41 破坏了云 TPU 支持。此版本再次修复了它。

  • JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更高版本。

jax 0.1.61 (2020 年 3 月 17 日)#

  • GitHub 提交.

  • 修复了 Python 3.5 支持。这将是支持 Python 3.5 的最后一个 JAX 或 jaxlib 版本。

jax 0.1.60 (2020 年 3 月 17 日)#

  • GitHub 提交.

  • 新功能

    • jax.pmap() 具有 static_broadcast_argnums 参数,允许用户指定应视为编译时常量并应广播到所有设备的参数。它的工作方式类似于 jax.jit() 中的 static_argnums

    • 改进了当追踪器错误地保存在全局状态中时的错误消息。

    • 添加了 jax.nn.one_hot() 实用函数。

    • 添加了 jax.experimental.jet,用于指数级更快的高阶自动微分。

    • jax.lax.broadcast_in_dim() 的参数添加了更多正确性检查。

  • 现在最低 jaxlib 版本为 0.1.41。

jaxlib 0.1.40 (2020 年 3 月 4 日)#

  • 在 Jaxlib 中添加了对 TensorFlow 分析器的实验性支持,允许从 TensorBoard 跟踪 CPU 和 GPU 计算。

  • 包含通过 NCCL 通信的多主机 GPU 计算的原型支持。

  • 提高了 GPU 上 NCCL 集合的性能。

  • 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。

  • 支持在 XLA 编译时已知的设备分配。

jax 0.1.59 (2020 年 2 月 11 日)#

  • GitHub 提交.

  • 重大更改

    • 现在最低 jaxlib 版本为 0.1.38。

    • 通过删除 Jaxpr.freevarsJaxpr.bound_subjaxprs 简化了 Jaxpr。调用原语(xla_callxla_pmapsharded_callremat_call)获得了一个新参数 call_jaxpr,它带有完全封闭的(没有 constvars)jaxpr。此外,为原语添加了一个新字段 call_primitive

  • 新功能

    • lax.cond 的反向模式自动微分(例如 grad),使其现在在两种模式下都可微分 (#2091)

    • JAX 现在支持 DLPack,它允许以零拷贝的方式与其他库(例如 PyTorch)共享 CPU 和 GPU 数组。

    • JAX GPU DeviceArrays 现在支持 __cuda_array_interface__,这是另一个用于与其他库(例如 CuPy 和 Numba)共享 GPU 数组的零拷贝协议。

    • JAX CPU 设备缓冲区现在实现了 Python 缓冲区协议,允许在 JAX 和 NumPy 之间进行零拷贝缓冲区共享。

    • 添加了 JAX_SKIP_SLOW_TESTS 环境变量以跳过已知较慢的测试。

jaxlib 0.1.39 (2020 年 2 月 11 日)#

  • 更新 XLA。

jaxlib 0.1.38 (2020 年 1 月 29 日)#

  • 不再支持 CUDA 9.0。

  • 默认情况下现在构建 CUDA 10.2 wheels。

jax 0.1.58 (2020 年 1 月 28 日)#

值得注意的 bug 修复#

  • 随着 Python 3 的升级,JAX 不再依赖 fastcache,这应该有助于安装。