jax.lax
模块#
jax.lax
是一个原始操作库,它是 jax.numpy
等库的基础。转换规则(如 JVP 和批处理规则)通常定义为对 jax.lax
原始操作的转换。
许多原始操作是对等效 XLA 操作的薄包装,这些操作在 XLA 操作语义文档中进行了描述。在少数情况下,JAX 会偏离 XLA,通常是为了确保操作集在 JVP 和转置规则的操作下是闭合的。
如果可能,请优先使用 jax.numpy
等库,而不是直接使用 jax.lax
。jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定,更不容易更改。
运算符#
|
逐元素绝对值:\(|x|\)。 |
|
逐元素反余弦:\(\mathrm{acos}(x)\)。 |
|
逐元素反双曲余弦:\(\mathrm{acosh}(x)\)。 |
|
逐元素加法:\(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
以近似方式返回 |
|
以近似方式返回 |
|
计算沿 |
|
计算沿 |
|
逐元素反正弦:\(\mathrm{asin}(x)\)。 |
|
逐元素反双曲正弦:\(\mathrm{asinh}(x)\)。 |
|
逐元素反正切:\(\mathrm{atan}(x)\)。 |
|
两个变量的逐元素反正切:\(\mathrm{atan}({x \over y})\)。 |
|
逐元素反双曲正切:\(\mathrm{atanh}(x)\)。 |
|
批量矩阵乘法。 |
|
指数缩放的 0 阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
指数缩放的 1 阶修正贝塞尔函数:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素正则化不完全贝塔积分。 |
|
逐元素位转换。 |
|
逐元素 AND:\(x \wedge y\)。 |
|
逐元素 NOT:\(\neg x\)。 |
|
逐元素 OR:\(x \vee y\)。 |
|
逐元素异或:\(x \oplus y\)。 |
逐元素 popcount,计算每个元素中设置的位数。 |
|
|
广播一个数组,添加新的前导维度 |
|
包装 XLA 的 BroadcastInDim 运算符。 |
返回 shapes 的 NumPy 广播产生的形状。 |
|
|
添加前导维度 |
|
围绕 |
|
逐元素立方根:\(\sqrt[3]{x}\)。 |
|
逐元素向上取整:\(\left\lceil x \right\rceil\)。 |
|
逐元素钳位操作。 |
|
逐元素计算前导零的数量。 |
|
将数组的多个维度折叠成单个维度。 |
|
逐元素构造复数:\(x + jy\)。 |
|
沿维度连接一系列数组。 |
|
逐元素复共轭函数:\(\overline{x}\)。 |
|
conv_general_dilated的便捷包装器。 |
|
逐元素类型转换。 |
|
将卷积dimension_numbers转换为ConvDimensionNumbers。 |
|
通用的n维卷积运算符,具有可选的空洞卷积。 |
|
通用的n维非共享卷积运算符,具有可选的空洞卷积。 |
|
提取conv_general_dilated感受野内的patch。 |
|
用于计算 N 维卷积“转置”的便捷包装器。 |
|
conv_general_dilated的便捷包装器。 |
|
逐元素余弦:\(\mathrm{cos}(x)\)。 |
|
逐元素双曲余弦:\(\mathrm{cosh}(x)\)。 |
|
计算沿轴的累计对数和指数。 |
|
计算沿轴的累计最大值。 |
|
计算沿轴的累计最小值。 |
|
计算沿轴的累计乘积。 |
|
计算沿轴的累计和。 |
|
逐元素双伽马函数:\(\psi(x)\)。 |
|
逐元素除法:\(x \over y\)。 |
|
向量/向量、矩阵/向量和矩阵/矩阵乘法。 |
|
通用点积/缩并运算符。 |
|
用于执行整数索引的动态切片的便捷包装器。 |
|
包装 XLA 的 DynamicSlice 运算符。 |
|
应用于一个维度的 |
|
|
|
包装 XLA 的 DynamicUpdateSlice 运算符。 |
|
|
|
逐元素等于:\(x = y\)。 |
|
逐元素误差函数:\(\mathrm{erf}(x)\)。 |
|
逐元素互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
逐元素逆误差函数:\(\mathrm{erf}^{-1}(x)\)。 |
|
逐元素指数:\(e^x\)。 |
|
在数组中插入任意数量的大小为 1 的维度。 |
|
逐元素 \(e^{x} - 1\)。 |
|
|
|
逐元素向下取整:\(\left\lfloor x \right\rfloor\)。 |
|
返回一个用fill_value填充的shape数组。 |
|
基于示例数组x创建类似 np.full 的完整数组。 |
|
收集运算符。 |
|
逐元素大于或等于:\(x \geq y\)。 |
|
逐元素大于:\(x > y\)。 |
|
逐元素正则化不完全伽马函数。 |
|
逐元素互补正则化不完全伽马函数。 |
|
逐元素提取虚部:\(\mathrm{Im}(x)\)。 |
|
围绕 |
|
|
|
逐元素幂运算:\(x^y\),其中 \(y\) 是一个固定的整数。 |
|
封装 XLA 的 Iota 运算符。 |
|
逐元素 \(\mathrm{isfinite}\)。 |
|
逐元素小于或等于:\(x \leq y\)。 |
|
逐元素对数伽马函数:\(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素自然对数:\(\mathrm{log}(x)\)。 |
|
逐元素 \(\mathrm{log}(1 + x)\)。 |
|
逐元素逻辑(sigmoid)函数:\(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素小于:\(x < y\)。 |
|
逐元素最大值:\(\mathrm{max}(x, y)\)。 |
|
逐元素最小值:\(\mathrm{min}(x, y)\)。 |
|
逐元素乘法:\(x \times y\)。 |
|
逐元素不等于:\(x \neq y\)。 |
|
逐元素取反:\(-x\)。 |
|
返回 x1 之后,朝向 x2 方向的下一个可表示的值。 |
|
阻止编译器跨越屏障移动操作。 |
|
对数组应用低、高和/或内部填充。 |
|
分阶段输出特定于平台代码。 |
|
逐元素多伽马函数:\(\psi^{(m)}(x)\)。 |
逐元素 popcount,计算每个元素中设置的位数。 |
|
|
逐元素幂运算:\(x^y\)。 |
|
来自 Gamma(a, 1) 的样本的逐元素导数。 |
|
逐元素提取实部:\(\mathrm{Re}(x)\)。 |
|
逐元素倒数:\(1 \over x\)。 |
|
封装 XLA 的 Reduce 运算符。 |
|
封装 XLA 的 ReducePrecision 运算符。 |
|
封装 XLA 的 ReduceWindowWithGeneralPadding 运算符。 |
|
逐元素取余:\(x \bmod y\)。 |
|
封装 XLA 的 Reshape 运算符。 |
|
封装 XLA 的 Rev 运算符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素四舍五入。 |
|
逐元素倒数平方根:\(1 \over \sqrt{x}\)。 |
|
散布更新运算符。 |
|
散布加法运算符。 |
|
散布应用运算符。 |
|
散布最大值运算符。 |
|
散布最小值运算符。 |
|
散布乘法运算符。 |
|
逐元素左移:\(x \ll y\)。 |
|
逐元素算术右移:\(x \gg y\)。 |
|
逐元素逻辑右移:\(x \gg y\)。 |
|
逐元素符号函数。 |
|
逐元素正弦函数:\(\mathrm{sin}(x)\)。 |
|
逐元素双曲正弦函数:\(\mathrm{sinh}(x)\)。 |
|
封装了 XLA 的 Slice 操作符。 |
|
围绕 |
|
封装了 XLA 的 Sort 操作符。 |
|
沿着 |
|
沿着 |
|
逐元素平方根:\(\sqrt{x}\)。 |
|
逐元素平方:\(x^2\)。 |
|
从数组中挤压任意数量的大小为 1 的维度。 |
|
逐元素减法:\(x - y\)。 |
|
逐元素正切:\(\mathrm{tan}(x)\)。 |
|
逐元素双曲正切:\(\mathrm{tanh}(x)\)。 |
|
返回 |
|
封装了 XLA 的 Transpose 操作符。 |
|
逐元素 Hurwitz zeta 函数:\(\zeta(x, q)\) |
控制流运算符#
|
使用关联的二元运算并行执行扫描。 |
|
有条件地应用 |
|
通过归约为 |
|
将函数映射到前导数组轴上。 |
|
在执行状态的同时,扫描前导数组轴上的函数。 |
|
基于布尔谓词在两个分支之间选择。 |
|
从多个案例中选择数组值。 |
|
应用由 |
|
当 |
自定义梯度运算符#
停止梯度计算。 |
|
|
使用隐式定义的梯度执行无矩阵线性求解。 |
|
可微地求解函数的根。 |
并行运算符#
|
跨所有副本收集 x 的值。 |
|
实现映射轴并映射不同的轴。 |
|
在 pmapped 轴 |
|
类似于 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
根据排列 |
|
jax.lax.ppermute 的便捷封装,带有备用排列编码 |
|
将 pmapped 轴 |
|
返回沿映射轴 |
线性代数运算符 (jax.lax.linalg)#
|
Cholesky 分解。 |
|
一般矩阵的特征分解。 |
|
厄米特矩阵的特征分解。 |
|
将一个方阵缩减为上Hessenberg形式。 |
|
带有部分选主的LU分解。 |
|
基本Householder反射器的乘积。 |
|
用于极分解的基于QR的动态加权Halley迭代。 |
|
QR分解。 |
|
|
|
奇异值分解。 |
|
三角解法。 |
|
将对称/厄米特矩阵缩减为三对角形式。 |
|
计算三对角线性系统的解。 |
参数类#
- jax.lax.ConvGeneralDilatedDimensionNumbers#
tuple
[str
,str
,str
] |ConvDimensionNumbers
|None
的别名
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#
指定用于计算点积的算法。
当用于指定
dot()
,dot_general()
和其他点积函数的precision
输入时,此数据结构用于控制用于计算点积的算法的属性。此 API 控制计算所用的精度,并允许用户访问特定于硬件的加速。对这些算法的支持取决于平台,并且当编译计算时,使用不支持的算法将引发 Python 异常。在至少某些平台上已知支持的算法在
DotAlgorithmPreset
枚举中列出,并且这些是尝试此 API 的良好起点。“点算法”由以下参数指定
lhs_precision_type
和rhs_precision_type
,操作的 LHS 和 RHS 四舍五入到的数据类型。accumulation_type
用于累加的数据类型。lhs_component_count
、rhs_component_count
和num_primitive_operations
适用于将 LHS 和/或 RHS 分解为多个组件并在这些值上执行多个操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为1
。allow_imprecise_accumulation
用于指定是否允许在某些步骤中以较低的精度进行累加(例如,CUBLASLT_MATMUL_DESC_FAST_ACCUM
)。
StableHLO 规范 对于点操作不要求精度类型与输入或输出的存储类型相同,但是某些平台可能要求这些类型匹配。此外,如果指定,
dot_general()
的返回类型始终由输入算法的accumulation_type
参数定义。示例
使用 32 位浮点累加器累加两个 16 位浮点数
>>> algorithm = DotAlgorithm( ... lhs_precision_type=np.float16, ... rhs_precision_type=np.float16, ... accumulation_type=np.float32, ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,使用预设
>>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
也可以按名称指定预设
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
preferred_element_type
参数可用于在不向下转换累积类型的情况下返回输出>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) array([ 1., 4., 9., 16.], dtype=float32)
- class jax.lax.DotAlgorithmPreset(value)[source]#
用于计算点积的已知算法的枚举。
此
Enum
提供了一组命名的DotAlgorithm
对象,已知至少在平台上受支持。有关这些算法行为的更多详细信息,请参阅DotAlgorithm
文档。在调用
dot()
、dot_general()
或大多数其他 JAX 点积函数时,可以通过传递此Enum
的成员或其名称作为字符串,并使用precision
参数,从此列表中选择一个算法。例如,用户可以使用此
Enum
直接指定预设>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,可以按名称指定它们
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
预设的名称通常为
LHS_RHS_ACCUM
,其中LHS
和RHS
分别是lhs
和rhs
输入的元素类型,ACCUM
是累加器的元素类型。某些预设有额外的后缀,并且每个后缀的含义在下面记录。支持的预设是- DEFAULT = 1#
将根据输入和输出类型选择算法。
- ANY_F8_ANY_F8_F32 = 2#
接受任何 float8 输入类型,并累积到 float32 中。
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#
类似于
ANY_F8_ANY_F8_F32
,但使用更快的累积,代价是精度较低。
- ANY_F8_ANY_F8_ANY = 4#
类似于
ANY_F8_ANY_F8_F32
,但累积类型由preferred_element_type
控制。
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#
类似于
ANY_F8_ANY_F8_F32_FAST_ACCUM
,但累积类型由preferred_element_type
控制。
- F16_F16_F16 = 6#
- F16_F16_F32 = 7#
- BF16_BF16_BF16 = 8#
- BF16_BF16_F32 = 9#
- BF16_BF16_F32_X3 = 10#
_X3
后缀表示该算法使用 3 次运算来模拟更高的精度。
- BF16_BF16_F32_X6 = 11#
类似于
BF16_BF16_F32_X3
,但使用 6 次运算而不是 3 次。
- TF32_TF32_F32 = 12#
- TF32_TF32_F32_X3 = 13#
_X3
后缀表示该算法使用 3 次运算来模拟更高的精度。
- F32_F32_F32 = 14#
- F64_F64_F64 = 15#
- class jax.lax.FftType(value)[源代码]#
描述要执行的 FFT 操作。
- FFT = 0#
正向复数到复数的 FFT。
- IFFT = 1#
反向复数到复数的 FFT。
- IRFFT = 3#
反向实数到复数的 FFT。
- RFFT = 2#
正向实数到复数的 FFT。
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[源代码]#
描述 XLA 的 Gather 运算符的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – gather 输出中偏移到从 operand 切片的数组中的维度集合。必须是一个按升序排列的整数元组,每个整数表示输出的维度编号。
collapsed_slice_dims (tuple[int, ...]) – operand 中 slice_sizes[i] == 1 且在 gather 输出中不应具有对应维度的维度 i 的集合。必须是一个按升序排列的整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要切片的对应维度。必须是大小等于 start_indices.shape[-1] 的整数元组。
operand_batching_dims (tuple[int, ...]) – operand 中 slice_sizes[i] == 1 且在 start_indices(在 start_indices_batching_dims 中的相同索引处)和 gather 输出中都应具有对应维度的批处理维度 i 的集合。必须是一个按升序排列的整数元组。
start_indices_batching_dims (tuple[int, ...]) – start_indices 中批量处理维度 i 的集合,这些维度应在 operand (在 operand_batching_dims 中对应的索引处)和 gather 的输出中都有相应的维度。必须是整数的元组(顺序根据与 operand_batching_dims 的对应关系确定)。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
- class jax.lax.GatherScatterMode(value)[source]#
描述如何在 gather 或 scatter 中处理越界索引。
可能的值为
- CLIP
索引将被钳制到最近的有效值,即,使得要 gather 的整个窗口都在范围内。
- FILL_OR_DROP
如果要 gather 的窗口的任何部分超出范围,则返回的整个窗口(即使是原本在范围内的元素)都将填充一个常量。如果 scatter 的窗口的任何部分超出范围,则将丢弃整个窗口。
- PROMISE_IN_BOUNDS
用户承诺索引在范围内。不会执行额外的检查。实际上,在当前的 XLA 实现中,这意味着越界的 gather 将被钳制,但越界的 scatter 将被丢弃。如果索引超出范围,则梯度将不正确。
- class jax.lax.Precision(value)[source]#
用于 lax 矩阵乘法相关函数的精度枚举。
JAX 函数的设备相关 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和准确性之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有影响,并且不会影响输入/输出数据类型。成员是
- DEFAULT
最快的模式,但最不准确。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如,在 A100 和 H100 GPU 上),否则使用标准 float32(例如,在 V100 GPU 上)。别名:
'default'
,'fastest'
。- HIGH
较慢但更准确。在 TPU 上:以 3 次 bfloat16 传递执行 float32 计算。在 GPU 上:在可用时使用 tensorfloat32,否则使用 float32。别名:
'high'
。- HIGHEST
最慢但最准确。在 TPU 上:以 6 次 bfloat16 执行 float32 计算。别名:
'highest'
。在 GPU 上:使用 float32。
- jax.lax.PrecisionLike#
别名:
None
|str
|Precision
|tuple
[str
,str
] |tuple
[Precision
,Precision
] |DotAlgorithm
|DotAlgorithmPreset
- class jax.lax.RandomAlgorithm(value)[source]#
描述用于 rng_bit_generator 的 PRNG 算法。
- RNG_DEFAULT = 0#
平台的默认算法。
- RNG_THREE_FRY = 1#
Threefry-2x32 PRNG 算法。
- RNG_PHILOX = 2#
Philox-4x32 PRNG 算法。
- class jax.lax.RoundingMethod(value)[source]#
用于处理
jax.lax.round()
中间值(例如,0.5)的舍入策略。- AWAY_FROM_ZERO = 0#
将中间值四舍五入远离零(例如,0.5 -> 1,-0.5 -> -1)。
- TO_NEAREST_EVEN = 1#
将中间值四舍五入到最近的偶数整数。这也被称为“银行家舍入”(例如,0.5 -> 0,1.5 -> 2)。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#
描述 XLA 的 Scatter 运算符的维度编号参数。 有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (Sequence[int]) – updates 中作为窗口维度的一组维度。必须是按升序排列的整数元组,每个整数代表一个维度编号。
inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集。必须是按升序排列的整数元组,每个整数代表输出的维度编号。在 gather 的情况下,这些是 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是大小等于 scatter_indices.shape[-1] 的整数序列。
operand_batching_dims (Sequence[int]) – operand 中应该在 scatter_indices (对应 scatter_indices_batching_dims 中相同索引的位置) 和 updates 中都有对应维度的批处理维度 i 的集合。 必须是一个按升序排列的整数元组。在 gather 的情况下,这些是 operand_batching_dims 的镜像。
scatter_indices_batching_dims (Sequence[int]) – scatter_indices 中应该在 operand (对应 operand_batching_dims 中相同索引的位置) 和 gather 的输出中都有对应维度的批处理维度 i 的集合。 必须是一个整数元组 (顺序根据与 input_batching_dims 的对应关系固定)。在 gather 的情况下,这些是 start_indices_batching_dims 的镜像。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,请添加大小为 1 的尾随维度。