jax.lax
模块#
jax.lax
是一个底层原语操作库,它支撑了诸如 jax.numpy
之类的库。转换规则,例如 JVP 和批处理规则,通常被定义为对 jax.lax
原语的转换。
许多原语是围绕等效 XLA 操作的简单包装器,由 XLA 操作语义 文档描述。在少数情况下,JAX 与 XLA 存在差异,通常是为了确保操作集在 JVP 和转置规则的操作下是封闭的。
尽可能使用诸如 jax.numpy
的库,而不是直接使用 jax.lax
。 jax.numpy
的 API 遵循 NumPy,因此比 jax.lax
的 API 更稳定,且不太可能发生变化。
运算符#
|
逐元素绝对值:\(|x|\)。 |
|
逐元素反余弦:\(\mathrm{acos}(x)\)。 |
|
逐元素反双曲余弦:\(\mathrm{acosh}(x)\)。 |
|
逐元素加法:\(x + y\)。 |
|
合并一个或多个 XLA 令牌值。 |
|
以近似方式返回 |
|
以近似方式返回 |
|
计算沿 |
|
计算沿 |
|
逐元素反正弦:\(\mathrm{asin}(x)\)。 |
|
逐元素反双曲正弦:\(\mathrm{asinh}(x)\)。 |
|
逐元素反正切:\(\mathrm{atan}(x)\)。 |
|
逐元素两个变量的反正切:\(\mathrm{atan}({x \over y})\)。 |
|
逐元素反双曲正切:\(\mathrm{atanh}(x)\)。 |
|
批量矩阵乘法。 |
|
指数缩放的0阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
|
指数缩放的1阶修正贝塞尔函数:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
|
逐元素正则化不完全贝塔积分。 |
|
逐元素位强制转换。 |
|
逐元素与:\(x \wedge y\)。 |
|
逐元素非:\(\neg x\)。 |
|
逐元素或:\(x \vee y\)。 |
|
逐元素异或:\(x \oplus y\)。 |
逐元素位计数,计算每个元素中设置位的数量。 |
|
|
广播数组,添加新的前导维度 |
|
包装 XLA 的 BroadcastInDim 运算符。 |
返回 NumPy 广播 shapes 后得到的结果形状。 |
|
|
添加前导维度 |
|
围绕 |
|
逐元素立方根:\(\sqrt[3]{x}\)。 |
|
逐元素向上取整:\(\left\lceil x \right\rceil\)。 |
|
逐元素钳位。 |
|
逐元素前导零计数。 |
|
将数组的维度折叠成单个维度。 |
|
逐元素创建复数:\(x + jy\)。 |
|
沿着 dimension 连接一系列数组。 |
|
逐元素复共轭函数:\(\overline{x}\)。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素强制转换。 |
|
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。 |
|
一般的 n 维卷积运算符,可选扩张。 |
|
具有可选扩张的一般 n 维非共享卷积运算符。 |
|
提取受 conv_general_dilated 感受野影响的补丁。 |
|
计算 N 维卷积“转置”的便捷包装器。 |
|
围绕 conv_general_dilated 的便捷包装器。 |
|
逐元素余弦:\(\mathrm{cos}(x)\)。 |
|
逐元素双曲余弦:\(\mathrm{cosh}(x)\)。 |
|
沿 axis 计算累积 logsumexp。 |
|
沿 axis 计算累积最大值。 |
|
沿 axis 计算累积最小值。 |
|
沿 axis 计算累积积。 |
|
沿 axis 计算累积和。 |
|
逐元素 digamma 函数:\(\psi(x)\)。 |
|
逐元素除法:\(x \over y\)。 |
|
向量/向量、矩阵/向量和矩阵/矩阵乘法。 |
|
通用点积/收缩运算符。 |
|
围绕 |
|
包装 XLA 的 DynamicSlice 运算符。 |
|
围绕 |
|
围绕 |
|
包装 XLA 的 DynamicUpdateSlice 运算符。 |
|
围绕 |
|
逐元素等于:\(x = y\)。 |
|
逐元素误差函数:\(\mathrm{erf}(x)\)。 |
|
逐元素互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)。 |
|
逐元素误差函数的反函数:\(\mathrm{erf}^{-1}(x)\)。 |
|
逐元素指数函数:\(e^x\)。 |
|
在数组中插入任意数量的大小为 1 的维度。 |
|
逐元素计算 \(e^{x} - 1\)。 |
|
|
|
逐元素向下取整:\(\left\lfloor x \right\rfloor\)。 |
|
返回一个形状为 shape、填充值为 fill_value 的数组。 |
|
基于示例数组 x 创建一个类似于 np.full 的填充数组。 |
|
收集运算符。 |
|
逐元素大于等于:\(x \geq y\)。 |
|
逐元素大于:\(x > y\)。 |
|
逐元素正则化不完全伽马函数。 |
|
逐元素互补正则化不完全伽马函数。 |
|
逐元素提取虚部:\(\mathrm{Im}(x)\)。 |
|
围绕 |
|
|
|
逐元素幂运算:\(x^y\),其中 \(y\) 是一个固定的整数。 |
|
包装 XLA 的 Iota 运算符。 |
|
逐元素 \(\mathrm{isfinite}\)。 |
|
逐元素小于等于:\(x \leq y\)。 |
|
逐元素对数伽马函数:\(\mathrm{log}(\Gamma(x))\)。 |
|
逐元素自然对数:\(\mathrm{log}(x)\)。 |
|
逐元素计算 \(\mathrm{log}(1 + x)\)。 |
|
逐元素逻辑斯蒂函数(sigmoid 函数):\(\frac{1}{1 + e^{-x}}\)。 |
|
逐元素小于:\(x < y\)。 |
|
逐元素最大值:\(\mathrm{max}(x, y)\) |
|
逐元素最小值:\(\mathrm{min}(x, y)\) |
|
逐元素乘法:\(x \times y\)。 |
|
逐元素不等于:\(x \neq y\)。 |
|
逐元素取负:\(-x\)。 |
|
返回在x2方向上x1之后的下一个可表示值。 |
|
阻止编译器跨越屏障移动操作。 |
|
对数组应用低、高和/或内部填充。 |
|
分阶段执行特定于平台的代码。 |
|
逐元素多伽马函数:\(\psi^{(m)}(x)\)。 |
逐元素位计数,计算每个元素中设置位的数量。 |
|
|
逐元素幂运算:\(x^y\)。 |
|
来自Gamma(a, 1)的样本的逐元素导数。 |
|
逐元素提取实部:\(\mathrm{Re}(x)\)。 |
|
逐元素倒数:\(1 \over x\)。 |
|
封装 XLA 的 Reduce 运算符。 |
|
封装 XLA 的 ReducePrecision 运算符。 |
|
|
|
逐元素取余:\(x \bmod y\)。 |
|
封装 XLA 的 Reshape 运算符。 |
|
封装 XLA 的 Rev 运算符。 |
|
无状态 PRNG 位生成器。 |
|
有状态 PRNG 生成器。 |
|
逐元素舍入。 |
|
逐元素倒数平方根:\(1 \over \sqrt{x}\)。 |
|
散射更新运算符。 |
|
散射加法运算符。 |
|
散射应用运算符。 |
|
散射最大值运算符。 |
|
散射最小值运算符。 |
|
散射乘法运算符。 |
|
逐元素左移:\(x \ll y\)。 |
|
逐元素算术右移:\(x \gg y\)。 |
|
逐元素逻辑右移:\(x \gg y\)。 |
|
逐元素符号函数。 |
|
逐元素正弦函数:\(\mathrm{sin}(x)\)。 |
|
逐元素双曲正弦函数:\(\mathrm{sinh}(x)\)。 |
|
封装 XLA 的 Slice 运算符。 |
|
围绕 |
|
封装 XLA 的 Sort 运算符。 |
|
沿着 |
|
逐元素平方根:\(\sqrt{x}\)。 |
|
逐元素平方:\(x^2\)。 |
|
从数组中压缩任意数量的大小为 1 的维度。 |
|
逐元素减法:\(x - y\)。 |
|
逐元素正切函数:\(\mathrm{tan}(x)\)。 |
|
逐元素双曲正切函数:\(\mathrm{tanh}(x)\)。 |
|
返回 |
|
封装 XLA 的 Transpose 运算符。 |
|
逐元素 Hurwitz zeta 函数:\(\zeta(x, q)\) |
控制流运算符#
|
并行执行具有结合二元运算的扫描。 |
|
有条件地应用 |
|
从 |
|
在数组的前导轴上映射函数。 |
|
在数组的前导轴上扫描函数,同时传递状态。 |
|
根据布尔谓词在两个分支之间进行选择。 |
|
从多个案例中选择数组值。 |
|
根据 |
|
当 |
自定义梯度算子#
停止梯度计算。 |
|
|
使用隐式定义的梯度执行无矩阵线性求解。 |
|
可微地求解函数的根。 |
并行算子#
|
收集所有副本中 x 的值。 |
|
具体化映射轴并映射不同的轴。 |
|
在 pmapped 轴 |
|
类似于 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
在 pmapped 轴 |
|
根据排列 |
|
jax.lax.ppermute 的便捷包装器,使用备用排列编码。 |
|
将 pmapped 轴 |
|
返回映射轴 |
线性代数算子 (jax.lax.linalg)#
|
Cholesky 分解。 |
|
一般矩阵的特征分解。 |
|
厄米特矩阵的特征分解。 |
|
将方阵简化为上 Hessenberg 形式。 |
|
带部分主元的 LU 分解。 |
|
基本 Householder 反射器的乘积。 |
|
基于 QR 的动态加权 Halley 迭代,用于极分解。 |
|
QR 分解。 |
|
|
|
奇异值分解。 |
|
三角求解。 |
|
将对称/厄米特矩阵简化为三对角形式。 |
|
计算三对角线性系统的解。 |
参数类#
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)[source]#
描述 XLA 的 Gather 运算符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
offset_dims (tuple[int, ...]) – gather 输出中偏移到从 operand 切片得到的数组的一组维度。必须是按升序排列的整数元组,每个整数表示输出的维度编号。
collapsed_slice_dims (tuple[int, ...]) – operand 中满足 slice_sizes[i] == 1 且不应在 gather 输出中具有对应维度的一组维度 i。必须是按升序排列的整数元组。
start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出要从中切片的 operand 中的对应维度。必须是大小等于 start_indices.shape[-1] 的整数元组。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加一个大小为 1 的尾随维度。
- class jax.lax.GatherScatterMode(value)[source]#
描述如何在 gather 或 scatter 中处理越界索引。
可能的值是
- CLIP
索引将被钳位到最近的范围内值,即,使得要收集的整个窗口都在范围内。
- FILL_OR_DROP
如果收集的窗口的任何部分超出范围,则返回的整个窗口(即使其他元素在范围内)都将填充一个常数。如果散布的窗口的任何部分超出范围,则整个窗口将被丢弃。
- PROMISE_IN_BOUNDS
用户保证索引在范围内。不会执行其他检查。实际上,使用当前的 XLA 实现,这意味着越界 gather 将被钳位,但越界 scatter 将被丢弃。如果索引超出范围,则梯度将不正确。
- class jax.lax.Precision(value)[source]#
lax 矩阵乘法相关函数的精度枚举。
JAX 函数中与设备相关的 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有效,并且不影响输入/输出数据类型。成员是
- DEFAULT
最快的模式,但精度最低。在 TPU 上:以 bfloat16 执行 float32 计算。在 GPU 上:如果可用,则使用 tensorfloat32(例如在 A100 和 H100 GPU 上),否则使用标准 float32(例如在 V100 GPU 上)。别名:
'default'
、'fastest'
。- HIGH
速度较慢但精度较高。在 TPU 上:在 3 个 bfloat16 传递中执行 float32 计算。在 GPU 上:在可用情况下使用 tensorfloat32,否则使用 float32。别名:
'high'
..- HIGHEST
速度最慢但精度最高。在 TPU 上:在 6 个 bfloat16 中执行 float32 计算。别名:
'highest'
。在 GPU 上:使用 float32。
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)[source]#
描述 XLA 的 Scatter 运算符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
- 参数:
update_window_dims (Sequence[int]) – updates 中的一组窗口维度。必须是按升序排列的整数元组,每个整数表示维度编号。
inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的窗口维度集。必须是按升序排列的整数元组,每个整数表示输出的维度编号。这些是 gather 情况下 collapsed_slice_dims 的镜像。
scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中的对应维度。必须是大小等于 scatter_indices.shape[-1] 的整数序列。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;始终存在索引向量维度,并且它必须始终是最后一个维度。要散布标量索引,请添加一个大小为 1 的尾随维度。