等级提升警告

等级提升警告#

NumPy 广播规则 允许自动将参数从一个等级(数组轴的数量)提升到另一个等级。这种行为在预期时可能很方便,但也可能导致意想不到的错误,其中无声的等级提升掩盖了潜在的形状错误。

这是一个等级提升的例子

>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0,  2,  2],
       [ 3,  5,  5],
       [ 6,  8,  8],
       [ 9, 11, 11]])

为了避免潜在的意外情况,jax.numpy 可配置,以便需要等级提升的表达式会导致警告、错误,或者可以像常规 NumPy 一样被允许。配置选项名为 jax_numpy_rank_promotion,它可以取字符串值 allowwarnraise。默认设置是 allow,它允许等级提升,不会发出警告或错误。 raise 设置在等级提升时引发错误,而 warn 在第一次发生等级提升时发出警告。

可以使用 jax.numpy_rank_promotion() 上下文管理器在本地启用或禁用等级提升

with jax.numpy_rank_promotion("warn"):
  z = x + y

这种配置也可以通过几种方式在全局范围内设置。一种是在代码中使用 jax.config

import jax
jax.config.update("jax_numpy_rank_promotion", "warn")

您也可以使用环境变量 JAX_NUMPY_RANK_PROMOTION 设置选项,例如 JAX_NUMPY_RANK_PROMOTION='warn'。最后,在使用 absl-py 时,可以使用命令行标志设置选项。