jax.example_libraries.optimizers 模块#

如何使用 JAX 编写优化器的示例。

您可能并不想导入这个模块!此库中的优化器仅作为示例。如果您正在寻找功能齐全的优化器库,请考虑使用 Optax

此模块包含一些方便的优化器定义,特别是初始化和更新函数,它们可以与 ndarray 或任意嵌套的 ndarray 元组/列表/字典一起使用。

优化器被建模为一个 (init_fun, update_fun, get_params) 函数三元组,其中组件函数具有以下签名

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.
update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.
get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

请注意,优化器的实现形式在 opt_state 方面具有很大的灵活性:它只需要是 JaxTypes 的 pytree(以便它可以传递给 api.py 中定义的 JAX 转换),并且它必须可被 update_fun 和 get_params 使用。

使用示例

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
class jax.example_libraries.optimizers.JoinPoint(subtree)[source]#

基类:object

标记两个连接(嵌套)的 pytree 之间的边界。

class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]#

基类:NamedTuple

参数:
  • init_fn (InitFn)

  • update_fn (UpdateFn)

  • params_fn (ParamsFn)

init_fn: InitFn#

字段编号 0 的别名

params_fn: ParamsFn#

字段编号 2 的别名

update_fn: UpdateFn#

字段编号 1 的别名

class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#

基类:tuple

packed_state#

字段编号 0 的别名

subtree_defs#

字段编号 2 的别名

tree_def#

字段编号 1 的别名

jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[source]#

为 Adagrad 构建优化器三元组。

用于在线学习和随机优化的自适应子梯度方法:http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

参数:
  • step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

  • momentum – 可选,动量的正标量值

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

为 Adam 构建优化器三元组。

参数:
  • step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

  • b1 – 可选,beta_1 的正标量值,即第一动量估计的指数衰减率(默认值 0.9)。

  • b2 – 可选,beta_2 的正标量值,即第二动量估计的指数衰减率(默认值 0.999)。

  • eps – 可选,epsilon 的正标量值,一个用于数值稳定性的较小常量(默认值 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

为 AdaMax(一种基于无穷范数的 Adam 变体)构建优化器三元组。

参数:
  • step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

  • b1 – 可选,beta_1 的正标量值,即第一动量估计的指数衰减率(默认值 0.9)。

  • b2 – 可选,beta_2 的正标量值,即第二动量估计的指数衰减率(默认值 0.999)。

  • eps – 可选,epsilon 的正标量值,一个用于数值稳定性的较小常量(默认值 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[source]#

将存储为数组 pytree 的梯度裁剪为最大范数 max_norm

jax.example_libraries.optimizers.constant(step_size)[source]#
返回类型:

Schedule

jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)[source]#
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]#
jax.example_libraries.optimizers.l2_norm(tree)[source]#

计算数组 pytree 的 l2 范数。用于权重衰减。

jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)[source]#
参数:

scalar_or_schedule (float | Schedule)

返回类型:

Schedule

jax.example_libraries.optimizers.momentum(step_size, mass)[源代码]#

构建带动量的 SGD 优化器三元组。

参数:
  • step_size (Schedule) – 正标量,或一个可调用对象,表示将迭代索引映射到正标量的步长调度。

  • mass (float) – 表示动量系数的正标量。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.nesterov(step_size, mass)[源代码]#

构建带 Nesterov 动量的 SGD 优化器三元组。

参数:
  • step_size (Schedule) – 正标量,或一个可调用对象,表示将迭代索引映射到正标量的步长调度。

  • mass (float) – 表示动量系数的正标量。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.optimizer(opt_maker)[源代码]#

用于使为数组定义的优化器推广到容器的装饰器。

使用此装饰器,您可以编写仅对单个数组进行操作的 init、update 和 get_params 函数,并将它们转换为对参数的 pytree 进行操作的相应函数。有关示例,请参见 optimizers.py 中定义的优化器。

参数:

opt_maker (Callable[..., tuple[Callable[[Params], State], Callable[[Step, Updates, Params], Params], Callable[[State], Params]]]) –

一个函数,返回一个 (init_fun, update_fun, get_params) 三元函数,这些函数可能仅适用于 ndarray,如

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

返回:

一个 (init_fun, update_fun, get_params) 三元函数,这些函数适用于任意的 pytree,如

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

返回的函数使用的 OptimizerState pytree 类型与 ParameterPytree (OptStatePytree ndarray) 同构,但可能会将状态存储为例如部分展平的数据结构以提高性能。

返回类型:

Callable[…, Optimizer]

jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[源代码]#

将标记的 pytree 转换为 OptimizerState。

unpack_optimizer_state 的逆操作。将外部 pytree 的叶子表示为 JoinPoints 的标记 pytree 转换回 OptimizerState。此函数旨在在反序列化优化器状态时使用。

参数:

marked_pytree – 一个包含 JoinPoint 叶子的 pytree,这些叶子持有更多的 pytree。

返回:

与输入参数等效的 OptimizerState。

jax.example_libraries.optimizers.piecewise_constant(boundaries, values)[源代码]#
参数:
  • boundaries (Any)

  • values (Any)

jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[源代码]#
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[源代码]#

构建 RMSProp 的优化器三元组。

参数:

step_size – 正标量,或一个可调用对象,表示将迭代索引映射到正标量的步长调度。gamma:衰减参数。eps:Epsilon 参数。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[源代码]#

构建带动量的 RMSProp 优化器三元组。

此优化器与 rmsprop 优化器分开,因为它需要跟踪额外的参数。

参数:
  • step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

  • gamma – 衰减参数。

  • eps – Epsilon 参数。

  • momentum – 动量参数。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.sgd(step_size)[源代码]#

构建随机梯度下降的优化器三元组。

参数:

step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[源代码]#

构建 SM3 的优化器三元组。

用于大规模学习的内存高效自适应优化。 https://arxiv.org/abs/1901.11150

参数:
  • step_size – 正标量,或表示步长计划的可调用对象,该计划将迭代索引映射到正标量。

  • momentum – 可选,动量的正标量值

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)[源代码]#

将 OptimizerState 转换为标记的 pytree。

将 OptimizerState 转换为标记的 pytree,其中外部 pytree 的叶子表示为 JoinPoints,以避免丢失信息。此函数旨在在序列化优化器状态时使用。

参数:

opt_state – 一个 OptimizerState

返回:

一个具有 JoinPoint 叶子的 pytree,其中包含第二级 pytree。