jax.lax
模块#
jax.lax
是一个原始操作库,它是 jax.numpy
等库的基础。诸如 JVP 和批处理规则之类的转换规则通常被定义为对 jax.lax
原始操作的转换。
许多原始操作都是对等效的 XLA 操作的简单封装,这些操作在 XLA 操作语义 文档中有描述。在少数情况下,JAX 会偏离 XLA,通常是为了确保操作集在 JVP 和转置规则的操作下是闭合的。
在可能的情况下,请优先使用 jax.numpy
等库,而不是直接使用 jax.lax
。jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定,更不容易更改。
运算符#
|
逐元素绝对值:\(|x|\)。 |
|
逐元素反余弦:\(\mathrm{acos}(x)\)。 |
|
逐元素反双曲余弦:\(\mathrm{acosh}(x)\)。 |
|
逐元素加法:\(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
以近似方式返回 |
|
以近似方式返回 |
|
计算沿 |
|
计算沿 |
|
逐元素反正弦:\(\mathrm{asin}(x)\)。 |
|
逐元素反双曲正弦:\(\mathrm{asinh}(x)\)。 |
|
逐元素反正切:\(\mathrm{atan}(x)\)。 |
|
两个变量的逐元素反正切:\(\mathrm{atan}({x \over y})\)。 |
|
逐元素反双曲正切:\(\mathrm{atanh}(x)\)。 |
|
批量矩阵乘法。 |
|
指数缩放的 0 阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
指数缩放的 1 阶修正贝塞尔函数:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素正则化不完全贝塔积分。 |
|
逐元素位转换。 |
|
逐元素与:\(x \wedge y\)。 |
|
逐元素非:\(\neg x\)。 |
|
逐元素或:\(x \vee y\)。 |
|
逐元素异或:\(x \oplus y\)。 |
逐元素 popcount,计算每个元素中设置的位数。 |
|
|
广播数组,添加新的前导维度。 |
|
封装 XLA 的 BroadcastInDim 运算符。 |
返回对 shapes 进行 NumPy 广播后得到的形状。 |
|
|
添加前导维度 |
|
围绕 |
|
逐元素立方根:\(\sqrt[3]{x}\)。 |
|
逐元素向上取整:\(\left\lceil x \right\rceil\)。 |
|
逐元素钳制。 |
|
逐元素前导零计数。 |
|
将数组的维度折叠为单个维度。 |
|
按元素创建复数:\(x + jy\)。 |
|
沿 dimension 连接数组序列。 |
|
按元素取复共轭函数:\(\overline{x}\)。 |
|
conv_general_dilated 的便捷包装器。 |
|
按元素进行类型转换。 |
|
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。 |
|
通用的 n 维卷积运算符,带有可选的膨胀。 |
|
通用的 n 维非共享卷积运算符,带有可选的膨胀。 |
|
提取 conv_general_dilated 的感受野的主题补丁。 |
|
用于计算 N 维卷积“转置”的便捷包装器。 |
|
conv_general_dilated 的便捷包装器。 |
|
按元素计算余弦值:\(\mathrm{cos}(x)\)。 |
|
按元素计算双曲余弦值:\(\mathrm{cosh}(x)\)。 |
|
沿 axis 计算累积 logsumexp。 |
|
沿 axis 计算累积最大值。 |
|
沿 axis 计算累积最小值。 |
|
沿 axis 计算累积乘积。 |
|
沿 axis 计算累积和。 |
|
按元素计算双伽玛函数:\(\psi(x)\)。 |
|
按元素计算除法:\(x \over y\)。 |
|
向量/向量、矩阵/向量和矩阵/矩阵乘法。 |
|
通用的点积/收缩运算符。 |
|
围绕 dynamic_slice 的便捷包装器,用于执行整数索引。 |
|
包装 XLA 的 DynamicSlice 运算符。 |
|
围绕应用于一个维度的 |
|
围绕 |
|
包装 XLA 的 DynamicUpdateSlice 运算符。 |
|
围绕 |
|
按元素比较是否相等:\(x = y\)。 |
|
按元素计算误差函数:\(\mathrm{erf}(x)\)。 |
|
按元素计算互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
按元素计算逆误差函数:\(\mathrm{erf}^{-1}(x)\)。 |
|
按元素计算指数:\(e^x\)。 |
|
将任意数量的大小为 1 的维度插入到数组中。 |
|
按元素计算 \(e^{x} - 1\)。 |
|
|
|
按元素计算下取整:\(\left\lfloor x \right\rfloor\)。 |
|
返回一个用 fill_value 填充的 shape 数组。 |
|
基于示例数组 x 创建类似 np.full 的完整数组。 |
|
收集运算符。 |
|
按元素比较是否大于或等于:\(x \geq y\)。 |
|
按元素比较是否大于:\(x > y\)。 |
|
按元素计算正则不完全伽玛函数。 |
|
按元素计算互补正则不完全伽玛函数。 |
|
按元素提取虚部:\(\mathrm{Im}(x)\)。 |
|
围绕 |
|
|
|
按元素计算幂:\(x^y\),其中 \(y\) 是一个固定的整数。 |
|
包装 XLA 的 Iota 运算符。 |
|
逐元素 \(\mathrm{isfinite}\)。 |
|
逐元素小于等于:\(x \leq y\)。 |
|
逐元素对数伽玛函数:\(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素自然对数:\(\mathrm{log}(x)\)。 |
|
逐元素 \(\mathrm{log}(1 + x)\)。 |
|
逐元素 logistic (sigmoid) 函数:\(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素小于:\(x < y\)。 |
|
逐元素最大值:\(\mathrm{max}(x, y)\) |
|
逐元素最小值:\(\mathrm{min}(x, y)\) |
|
逐元素乘法:\(x \times y\)。 |
|
逐元素不等于:\(x \neq y\)。 |
|
逐元素取反:\(-x\)。 |
|
返回 x1 沿 x2 方向的下一个可表示的值。 |
|
阻止编译器跨越障碍移动操作。 |
|
对数组应用低、高和/或内部填充。 |
|
分阶段处理特定于平台的代码。 |
|
逐元素多伽玛函数:\(\psi^{(m)}(x)\)。 |
逐元素 popcount,计算每个元素中设置的位数。 |
|
|
逐元素幂:\(x^y\)。 |
|
从 Gamma(a, 1) 中采样的元素的导数。 |
|
逐元素提取实部:\(\mathrm{Re}(x)\)。 |
|
逐元素倒数:\(1 \over x\)。 |
|
包装 XLA 的 Reduce 操作符。 |
|
包装 XLA 的 ReducePrecision 操作符。 |
|
|
|
逐元素余数:\(x \bmod y\)。 |
|
包装 XLA 的 Reshape 操作符。 |
|
包装 XLA 的 Rev 操作符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素四舍五入。 |
|
逐元素倒数平方根:\(1 \over \sqrt{x}\)。 |
|
分散更新操作符。 |
|
分散相加操作符。 |
|
分散应用操作符。 |
|
分散最大值操作符。 |
|
分散最小值操作符。 |
|
分散乘法操作符。 |
|
逐元素左移:\(x \ll y\)。 |
|
逐元素算术右移:\(x \gg y\)。 |
|
逐元素逻辑右移:\(x \gg y\)。 |
|
逐元素符号函数。 |
|
逐元素正弦:\(\mathrm{sin}(x)\)。 |
|
逐元素双曲正弦:\(\mathrm{sinh}(x)\)。 |
|
包装 XLA 的 Slice 操作符。 |
|
围绕 |
|
包装 XLA 的 Sort 操作符。 |
|
沿 |
|
沿 |
|
逐元素平方根:\(\sqrt{x}\)。 |
|
逐元素平方:\(x^2\)。 |
|
从数组中挤压任意数量的大小为 1 的维度。 |
|
逐元素减法:\(x - y\)。 |
|
逐元素正切:\(\mathrm{tan}(x)\)。 |
|
逐元素双曲正切函数:\(\mathrm{tanh}(x)\)。 |
|
返回 |
|
包装 XLA 的 Transpose 操作符。 |
|
逐元素 Hurwitz zeta 函数:\(\zeta(x, q)\) |
控制流操作符#
|
并行执行具有结合律的二元操作的扫描。 |
|
有条件地应用 |
|
通过归约为 |
|
将函数映射到前导数组轴上。 |
|
在携带状态的同时,将函数扫描到前导数组轴上。 |
|
基于布尔谓词在两个分支之间进行选择。 |
|
从多个情况中选择数组值。 |
|
应用由 |
|
当 |
自定义梯度操作符#
停止梯度计算。 |
|
|
使用隐式定义的梯度执行无矩阵的线性求解。 |
|
可微分地求解函数的根。 |
并行操作符#
|
在所有副本中收集 x 的值。 |
|
将映射的轴具体化并映射不同的轴。 |
|
在 pmapped 轴 |
|
类似于 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
根据置换 |
|
jax.lax.ppermute 的便捷包装器,带有备用置换编码 |
|
将 pmapped 轴 |
|
返回沿映射轴 |
线性代数操作符 (jax.lax.linalg)#
|
Cholesky 分解。 |
|
一般矩阵的特征分解。 |
|
厄米矩阵的特征分解。 |
|
将方阵简化为上Hessenberg形式。 |
|
使用部分旋转的 LU 分解。 |
|
基本 Householder 反射器的乘积。 |
|
用于极分解的基于 QR 的动态加权哈雷迭代。 |
|
QR 分解。 |
|
|
|
奇异值分解。 |
|
三角求解。 |
|
将对称/厄米矩阵简化为三对角形式。 |
|
计算三对角线性系统的解。 |
参数类#
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#
指定用于计算点积的算法。
当用于指定
precision
输入给dot()
,dot_general()
和其他点积函数时,此数据结构用于控制用于计算点积的算法的属性。此 API 控制计算所用的精度,并允许用户访问特定于硬件的加速。对这些算法的支持取决于平台,当编译计算时,使用不支持的算法将引发 Python 异常。在至少某些平台上已知支持的算法列在
DotAlgorithmPreset
枚举中,这些是尝试此 API 的一个很好的起点。“点积算法”由以下参数指定:
lhs_precision_type
和rhs_precision_type
,即操作的左侧 (LHS) 和右侧 (RHS) 被四舍五入到的数据类型。accumulation_type
,用于累加的数据类型。lhs_component_count
,rhs_component_count
和num_primitive_operations
适用于将 LHS 和/或 RHS 分解为多个组件并在这些值上执行多个操作的算法,通常是为了模拟更高的精度。对于没有分解的算法,这些值应设置为1
。allow_imprecise_accumulation
指定是否允许在某些步骤中以较低精度累加(例如CUBLASLT_MATMUL_DESC_FAST_ACCUM
)。
dot 操作的 StableHLO 规范 不要求精度类型与输入或输出的存储类型相同,但某些平台可能要求这些类型匹配。此外,
dot_general()
的返回类型始终由输入算法的accumulation_type
参数定义(如果指定)。示例
使用 32 位浮点累加器累加两个 16 位浮点数
>>> algorithm = DotAlgorithm( ... lhs_precision_type=np.float16, ... rhs_precision_type=np.float16, ... accumulation_type=np.float32, ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,使用预设
>>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
预设也可以按名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
preferred_element_type
参数可用于返回输出而不向下转换累加类型>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) array([ 1., 4., 9., 16.], dtype=float32)
- class jax.lax.DotAlgorithmPreset(value)[source]#
用于计算点积的已知算法的枚举。
此
Enum
提供了一组命名的DotAlgorithm
对象,已知这些对象在至少一个平台上受支持。有关这些算法行为的更多详细信息,请参阅DotAlgorithm
文档。在调用
dot()
,dot_general()
或大多数其他 JAX 点积函数时,可以通过传递此Enum
的成员或其名称作为字符串使用precision
参数来从此列表中选择算法。例如,用户可以直接使用此
Enum
指定预设>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> algorithm = DotAlgorithmPreset.F16_F16_F32 >>> dot(lhs, rhs, precision=algorithm) array([ 1., 4., 9., 16.], dtype=float16)
或者,等效地,它们可以按名称指定
>>> dot(lhs, rhs, precision="F16_F16_F32") array([ 1., 4., 9., 16.], dtype=float16)
预设的名称通常为
LHS_RHS_ACCUM
,其中LHS
和RHS
分别是lhs
和rhs
输入的元素类型,而ACCUM
是累加器的元素类型。某些预设具有额外的后缀,每个后缀的含义如下所述。支持的预设包括:- DEFAULT = 1#
将根据输入和输出类型选择算法。
- ANY_F8_ANY_F8_F32 = 2#
接受任何 float8 输入类型,并累加到 float32。
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#
与
ANY_F8_ANY_F8_F32
类似,但使用更快的累加,代价是精度较低。
- ANY_F8_ANY_F8_ANY = 4#
与
ANY_F8_ANY_F8_F32
类似,但累加类型由preferred_element_type
控制。
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#
与
ANY_F8_ANY_F8_F32_FAST_ACCUM
类似,但累加类型由preferred_element_type
控制。
- F16_F16_F16 = 6#
- F16_F16_F32 = 7#
- BF16_BF16_BF16 = 8#
- BF16_BF16_F32 = 9#
- BF16_BF16_F32_X3 = 10#
后缀
_X3
表示该算法使用 3 个操作来模拟更高的精度。
- BF16_BF16_F32_X6 = 11#
类似于
BF16_BF16_F32_X3
,但使用 6 个操作而不是 3 个。
- TF32_TF32_F32 = 12#
- TF32_TF32_F32_X3 = 13#
后缀
_X3
表示该算法使用 3 个操作来模拟更高的精度。
- F32_F32_F32 = 14#
- F64_F64_F64 = 15#
- class jax.lax.FftType(value)[源代码]#
描述要执行的 FFT 操作。
- FFT = 0#
正向复数到复数 FFT。
- IFFT = 1#
反向复数到复数 FFT。
- IRFFT = 3#
反向实数到复数 FFT。
- RFFT = 2#
正向实数到复数 FFT。
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[源代码]#
描述 XLA 的 Gather 运算符的维度编号参数。有关维度编号的含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – gather 输出中偏移到从 operand 切片数组中的维度集合。必须是按升序排列的整数元组,每个整数代表输出的维度编号。
collapsed_slice_dims (tuple[int, ...]) – operand 中 slice_sizes[i] == 1 且在 gather 输出中不应具有相应维度的维度 i 的集合。必须是按升序排列的整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要切片的相应维度。必须是大小等于 start_indices.shape[-1] 的整数元组。
operand_batching_dims (tuple[int, ...]) – operand 中具有 slice_sizes[i] == 1 并且在 start_indices 中(在 start_indices_batching_dims 中的相同索引处)和 gather 输出中都应具有相应维度的批处理维度 i 的集合。必须是按升序排列的整数元组。
start_indices_batching_dims (tuple[int, ...]) – start_indices 中应在 operand 中(在 operand_batching_dims 中的相同索引处)和 gather 输出中都具有相应维度的批处理维度 i 的集合。必须是整数元组(顺序根据与 operand_batching_dims 的对应关系固定)。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在一个索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
- class jax.lax.GatherScatterMode(value)[源代码]#
描述如何在 gather 或 scatter 中处理越界索引。
可能的值为
- CLIP
索引将被钳制到最近的范围内值,即,使要 gather 的整个窗口都在范围内。
- FILL_OR_DROP
如果 gather 窗口的任何部分超出范围,则返回的整个窗口(即使是那些在范围内,也会用常量填充。如果 scattered 窗口的任何部分超出范围,则将丢弃整个窗口。
- PROMISE_IN_BOUNDS
用户承诺索引在范围内。不会执行其他检查。实际上,使用当前的 XLA 实现,这意味着越界 gather 将被钳制,但越界 scatter 将被丢弃。如果索引越界,则梯度将不正确。
- class jax.lax.Precision(value)[源代码]#
用于 lax 矩阵乘法相关函数的精度枚举。
JAX 函数中与设备相关的 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有效,并且不影响输入/输出数据类型。成员包括:
- DEFAULT
最快模式,但精度最低。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如,在 A100 和 H100 GPU 上),否则使用标准 float32(例如,在 V100 GPU 上)。别名:
'default'
,'fastest'
。- HIGH
较慢但更精确。在 TPU 上:以 3 次 bfloat16 传递执行 float32 计算。在 GPU 上:在可用时使用 tensorfloat32,否则使用 float32。别名:
'high'
。- HIGHEST
最慢但最精确。在 TPU 上:以 6 次 bfloat16 传递执行 float32 计算。别名:
'highest'
。在 GPU 上:使用 float32。
- jax.lax.PrecisionLike#
None
|str
|Precision
|tuple
[str
,str
] |tuple
[Precision
,Precision
] |DotAlgorithm
|DotAlgorithmPreset
的别名
- class jax.lax.RandomAlgorithm(value)[源代码]#
描述用于 rng_bit_generator 的 PRNG 算法。
- RNG_DEFAULT = 0#
平台的默认算法。
- RNG_THREE_FRY = 1#
Threefry-2x32 PRNG 算法。
- RNG_PHILOX = 2#
Philox-4x32 PRNG 算法。
- class jax.lax.RoundingMethod(value)[源代码]#
在
jax.lax.round()
中处理中间值(例如 0.5)的舍入策略。- AWAY_FROM_ZERO = 0#
将中间值四舍五入到远离零的值(例如,0.5 -> 1, -0.5 -> -1)。
- TO_NEAREST_EVEN = 1#
将中间值四舍五入到最接近的偶数整数。这也称为“银行家舍入”(例如,0.5 -> 0, 1.5 -> 2)。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[源代码]#
描述 XLA 的 Scatter 运算符的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (Sequence[int]) – updates 中作为窗口维度的一组维度。必须是按升序排列的整数元组,每个整数代表一个维度编号。
inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的一组窗口维度。必须是按升序排列的整数元组,每个整数代表输出的维度编号。在 gather 的情况下,它们是 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中的相应维度。必须是大小等于 scatter_indices.shape[-1] 的整数序列。
operand_batching_dims (Sequence[int]) – operand 中应该在 scatter_indices(在 scatter_indices_batching_dims 中的相同索引处)和 updates 中都有相应维度的一组批处理维度 i。必须是按升序排列的整数元组。在 gather 的情况下,它们是 operand_batching_dims 的镜像。
scatter_indices_batching_dims (Sequence[int]) – scatter_indices 中应该在 operand(在 operand_batching_dims 中的相同索引处)和 gather 的输出中都有相应维度的一组批处理维度 i。必须是整数元组(顺序根据与 input_batching_dims 的对应关系确定)。在 gather 的情况下,它们是 start_indices_batching_dims 的镜像。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,请添加大小为 1 的尾随维度。