jax.numpy
模块#
使用 jax.lax
中的原语实现 NumPy API。
虽然 JAX 尝试尽可能紧密地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,因此 NumPy 中会原地修改数组的 API 无法在 JAX 中实现。但是,JAX 通常能够提供一种纯函数式的替代 API。例如,JAX 提供了纯索引更新函数
x.at[i].set(y)
(参见ndarray.at
)来替代原地数组更新 (x[i] = y
)。类似地,一些 NumPy 函数在可能的情况下会返回数组的视图(例如
transpose()
和reshape()
)。JAX 版本的此类函数将返回副本,但当使用jax.jit()
编译一系列操作时,XLA 通常会优化掉这些副本。NumPy 在将值提升到
float64
类型方面非常激进。JAX 在类型提升方面有时不那么激进(请参阅 类型提升语义)。一些 NumPy 例程具有数据相关的输出形状(例如
unique()
和nonzero()
)。由于 XLA 编译器要求数组形状在编译时已知,因此此类操作与 JIT 不兼容。因此,JAX 为此类函数添加了一个可选的size
参数,可以在静态情况下指定该参数以便与 JIT 一起使用。
几乎所有适用的 NumPy 函数都在 jax.numpy
命名空间中实现;它们在下面列出。
用于索引更新功能的辅助属性。 |
|
|
|
|
逐元素计算绝对值。 |
|
|
|
|
逐元素相加两个数组。 |
|
|
测试给定轴上的所有数组元素是否都评估为 True。 |
|
检查两个数组是否在容差范围内逐元素近似相等。 |
|
|
|
|
|
返回复数值数字或数组的角度。 |
|
测试给定轴上的任何数组元素是否评估为 True。 |
|
返回一个新数组,该数组在原始数组的末尾附加了值。 |
|
沿轴将函数应用于 1D 数组切片。 |
|
在指定的轴上重复应用函数。 |
|
创建均匀间隔值的数组。 |
|
计算输入的三角余弦的逐元素反函数。 |
|
计算输入的双曲余弦的逐元素反函数。 |
|
计算输入的三角正弦的逐元素反函数。 |
|
计算输入的双曲正弦的逐元素反函数。 |
|
计算输入的三角正切的逐元素反函数。 |
|
计算 x1/x2 的反正切,选择正确的象限。 |
|
计算输入的双曲正切的逐元素反函数。 |
|
返回数组最大值的索引。 |
|
返回数组最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回对数组进行排序的索引。 |
|
查找非零数组元素的索引 |
|
|
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示形式。 |
|
将数组拆分为子数组。 |
|
返回数组中数据的字符串表示形式。 |
|
将对象转换为 JAX 数组。 |
|
|
|
|
|
将数组转换为指定的 dtype。 |
|
|
|
|
|
|
|
将输入转换为至少具有 1 个维度的数组。 |
|
将输入转换为至少具有 2 个维度的数组。 |
|
将输入转换为至少具有 3 个维度的数组。 |
|
计算加权平均值。 |
|
返回大小为 M 的 Bartlett 窗口。 |
|
计算整数数组中每个值的出现次数。 |
按元素计算按位与运算。 |
|
|
计算 |
|
|
|
|
|
|
按元素计算按位或运算。 |
|
|
|
按元素计算按位异或运算。 |
|
|
返回大小为 M 的 Blackman 窗口。 |
|
从块列表创建数组。 |
|
|
|
将数组广播到共同的形状。 |
|
将输入形状广播到共同的输出形状。 |
|
将数组广播到指定的形状。 |
沿最后一个轴连接切片、标量和类似数组的对象。 |
|
|
如果可以根据类型转换规则在数据类型之间进行转换,则返回 True。 |
|
计算输入数组的逐元素立方根。 |
|
|
|
将输入向上舍入到最接近的整数。 |
所有字符串标量类型的抽象基类。 |
|
|
通过堆叠选择数组的切片来构造数组。 |
|
将数组值裁剪到指定的范围。 |
|
按列堆叠数组。 |
|
|
|
类型为 complex128 的 JAX 标量构造函数。 |
|
类型为 complex64 的 JAX 标量构造函数。 |
所有由浮点数组成的复数标量类型的抽象基类。 |
|
将复数 dtype 转换为实数 dtype 时引发的警告。 |
|
|
使用布尔条件压缩沿给定轴的数组。 |
|
沿现有轴连接数组。 |
|
沿现有轴连接数组。 |
|
|
|
返回输入的逐元素复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
计算皮尔逊相关系数。 |
|
两个一维数组的相关性。 |
|
计算输入中每个元素的三角余弦值。 |
|
计算输入的双曲余弦值(逐元素计算)。 |
|
返回给定轴上非零元素的数量。 |
|
估计加权样本协方差。 |
|
计算两个数组的(批次)叉积。 |
|
|
|
沿轴的元素的累积乘积。 |
|
沿轴的元素的累积和。 |
|
沿数组轴的累积乘积。 |
|
沿数组轴的累积和。 |
|
将角度从度数转换为弧度。 |
|
|
|
从数组中删除一个或多个条目。 |
|
返回指定的对角线或构造对角线数组。 |
|
返回访问多维数组主对角线的索引。 |
|
返回访问给定数组主对角线的索引。 |
|
返回一个二维数组,其对角线上铺设了展平的输入数组。 |
|
返回数组的指定对角线。 |
|
计算沿给定轴的数组元素之间的 n 阶差分。 |
|
将数组转换为 bin 索引。 |
|
|
|
逐元素计算 x1 除以 x2 的整数商和余数 |
|
计算两个数组的点积。 |
|
|
|
按深度方向将数组拆分为子数组。 |
|
按深度方向堆叠数组。 |
|
创建数据类型对象。 |
|
计算展平数组的元素之差。 |
|
爱因斯坦求和 |
|
评估最佳收缩路径,而不评估 einsum。 |
|
创建一个空数组。 |
|
创建一个具有与数组相同形状和 dtype 的空数组。 |
|
返回 |
|
计算输入的逐元素指数。 |
|
计算输入的逐元素以 2 为底的指数。 |
|
将长度为 1 的维度插入数组 |
|
计算输入的每个元素的 |
|
返回满足条件的数组元素。 |
|
创建方形或矩形单位矩阵 |
|
计算实值输入的逐元素绝对值。 |
|
返回一个对角线被覆盖的数组副本。 |
|
浮点类型的机器限制。 |
|
将输入四舍五入到最接近零的整数。 |
|
返回展平数组中非零元素的索引 |
|
所有没有预定义长度的标量类型的抽象基类。 |
|
反转数组沿给定轴的元素顺序。 |
|
沿轴 1 反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
|
|
|
计算 |
|
一个 JAX 标量构造函数,类型为 float16。 |
|
一个 JAX 标量构造函数,类型为 float32。 |
|
一个 JAX 标量构造函数,类型为 float64。 |
|
所有浮点标量类型的抽象基类。 |
|
将输入向下舍入到最接近的整数。 |
|
计算 x1 除以 x2 的元素级向下取整除法 |
|
返回输入数组的元素级最大值。 |
|
返回输入数组的元素级最小值。 |
|
计算元素级浮点数模运算。 |
|
将浮点数值分割为尾数和 2 的指数。 |
|
将缓冲区转换为 1-D JAX 数组。 |
|
jnp.fromfile 的未实现 JAX 包装器。 |
|
从应用于索引的函数创建数组。 |
|
jnp.fromiter 的未实现 JAX 包装器。 |
|
从任意兼容 JAX 的标量函数创建 JAX ufunc。 |
|
将文本字符串转换为 1-D JAX 数组。 |
|
通过 DLPack 构建 JAX 数组。 |
|
创建一个充满指定值的数组。 |
|
创建一个与数组具有相同形状和 dtype 的,充满指定值的数组。 |
|
计算两个数组的最大公约数。 |
|
numpy 标量类型的基类。 |
|
生成以几何方式间隔的值。 |
返回当前的打印选项。 |
|
|
计算采样函数的数值梯度。 |
|
返回 |
|
返回 |
|
返回大小为 M 的汉明窗。 |
|
返回大小为 M 的汉宁窗。 |
|
计算 heaviside 阶跃函数。 |
|
计算一维直方图。 |
|
计算直方图的 bin 边缘。 |
|
计算二维直方图。 |
|
计算 N 维直方图。 |
|
将数组水平分割为子数组。 |
|
水平堆叠数组。 |
|
返回直角三角形给定边的元素级斜边。 |
|
计算第一类零阶修正贝塞尔函数。 |
|
创建一个方阵单位矩阵。 |
|
|
|
返回复数参数的元素级虚部。 |
一种构建数组索引元组的更好方法。 |
|
|
生成网格索引数组。 |
|
所有数值标量类型的抽象基类,其范围内的值具有(可能)不精确的表示形式,例如浮点数。 |
|
计算两个数组的内积。 |
|
在指定索引处将条目插入数组。 |
|
|
|
一个 JAX 标量构造函数,类型为 int16。 |
|
一个 JAX 标量构造函数,类型为 int32。 |
|
一个 JAX 标量构造函数,类型为 int64。 |
|
一个 JAX 标量构造函数,类型为 int8。 |
|
所有整数标量类型的抽象基类。 |
|
一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
计算输入的按位取反。 |
|
检查两个数组的元素是否在容差范围内近似相等。 |
|
返回一个布尔数组,显示输入是复数的位置。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否属于指定的种类。 |
|
返回一个布尔数组,指示输入的每个元素是否为有限数。 |
|
确定 |
|
返回一个布尔数组,指示输入的每个元素是否为无穷大。 |
|
返回一个布尔数组,指示输入的每个元素是否为 |
|
返回一个布尔数组,指示输入的每个元素是否为负无穷大。 |
|
返回一个布尔数组,指示输入的每个元素是否为正无穷大。 |
|
返回一个布尔数组,显示输入是实数的位置。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果输入是标量,则返回 True。 |
|
如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。 |
|
检查一个对象是否可以被迭代。 |
|
从 N 个一维序列返回一个多维网格(开放网格)。 |
|
返回大小为 M 的 Kaiser 窗口。 |
|
计算两个输入数组的克罗内克积。 |
|
计算两个数组的最小公倍数。 |
|
计算 x1 * 2 ** x2 |
|
将 |
|
返回 |
|
返回 |
|
按字典顺序对键序列进行排序。 |
|
返回间隔内均匀间隔的数字。 |
|
从 npy 文件加载 JAX 数组。 |
|
计算输入的按元素自然对数。 |
|
按元素计算 x 的以 10 为底的对数 |
|
计算一加输入的按元素对数, |
|
按元素计算 |
计算 |
|
以 2 为底的输入指数之和的对数,避免溢出。 |
|
按元素计算逻辑 AND 运算。 |
|
|
按元素计算 NOT bool(x)。 |
按元素计算逻辑 OR 运算。 |
|
按元素计算逻辑 XOR 运算。 |
|
|
生成对数间隔的值。 |
|
返回 (n, n) 数组掩码的索引。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
批量矩阵向量乘积。 |
|
返回沿给定轴的数组元素的最大值。 |
|
返回输入数组的元素级最大值。 |
|
返回沿给定轴的数组元素的平均值。 |
|
返回沿给定轴的数组元素的中位数。 |
|
从 N 个一维向量构造 N 维网格数组。 |
返回密集的多维“meshgrid”。 |
|
|
返回沿给定轴的数组元素的最小值。 |
|
返回输入数组的元素级最小值。 |
|
|
|
返回输入数组的元素级小数部分和整数部分。 |
|
将数组轴移动到新位置 |
对两个数组执行元素级乘法。 |
|
|
替换数组中的 NaN 和无穷大条目。 |
|
返回数组最大值的索引,忽略 NaN。 |
|
返回数组最小值的索引,忽略 NaN。 |
|
沿轴的元素的累积积,忽略 NaN 值。 |
|
沿轴的元素的累积和,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的最大值,忽略 NaN。 |
|
返回沿给定轴的数组元素的平均值,忽略 NaN。 |
|
返回沿给定轴的数组元素的中位数,忽略 NaN。 |
|
返回沿给定轴的数组元素的最小值,忽略 NaN。 |
|
计算沿指定轴的数据的百分位数,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的乘积,忽略 NaN。 |
|
计算沿指定轴的数据的分位数,忽略 NaN。 |
|
计算沿给定轴的标准差,忽略 NaN。 |
|
返回沿给定轴的数组元素的总和,忽略 NaN。 |
|
计算沿给定轴的数组元素的方差,忽略 NaN。 |
|
|
|
返回数组的维度数。 |
返回输入的元素级负值。 |
|
|
返回元素级 |
|
返回数组的非零元素的索引。 |
|
返回 |
|
所有数字标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“meshgrid”。 |
|
|
创建一个充满 1 的数组。 |
|
创建一个具有与数组相同形状和 dtype 的 1 数组。 |
|
计算两个数组的外积。 |
|
将位数组打包为 uint8 数组。 |
|
向数组添加填充。 |
|
返回数组的部分排序副本。 |
|
计算沿指定轴的数据的百分位数。 |
|
置换数组的轴/维度。 |
|
计算在域中分段定义的函数。 |
|
基于掩码更新数组元素。 |
|
返回给定根序列的多项式的系数。 |
|
返回两个多项式的和。 |
|
返回多项式的指定阶导数的系数。 |
|
返回多项式除法的商和余数。 |
|
数据的最小二乘多项式拟合。 |
|
返回多项式的指定阶积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
在特定值处计算多项式。 |
|
返回输入的按元素取正值的结果。 |
|
是 |
|
计算 |
|
用于设置打印选项的上下文管理器。 |
|
返回给定轴上数组元素的乘积。 |
|
返回二元运算应该将其参数转换为的类型。 |
|
返回给定轴上的峰峰值范围。 |
|
将元素放入给定索引处的数组中。 |
|
通过匹配一维索引和数据切片,将值放入目标数组。 |
|
计算沿指定轴的数据分位数。 |
沿第一个轴连接切片、标量和类数组对象。 |
|
|
将角度从弧度转换为度。 |
|
是 |
|
将数组展平为一维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的按元素取实部的值。 |
|
计算输入的按元素取倒数的值。 |
|
返回除法的按元素取余数的值。 |
|
从重复元素构造一个数组。 |
|
返回数组的重塑副本。 |
|
返回具有指定形状的新数组。 |
|
返回将 JAX 提升规则应用于输入的结果。 |
|
将 |
|
将 x 的元素四舍五入到最接近的整数。 |
|
沿指定轴滚动数组的元素。 |
|
将指定的轴滚动到给定的位置。 |
|
返回给定系数 |
|
在轴指定的平面中将数组逆时针旋转 90 度。 |
|
将输入均匀地四舍五入到给定的小数位数。 |
一种构建数组索引元组的更好方法。 |
|
|
以 NumPy |
|
将多个数组以未压缩的 |
|
在排序的数组中执行二分搜索。 |
|
根据一系列条件选择值。 |
|
设置打印选项。 |
|
计算两个一维数组的集合差。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入符号的按元素指示。 |
|
返回数组元素的符号位。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入的每个元素的三角正弦值。 |
|
计算归一化 sinc 函数。 |
是 |
|
|
计算输入的按元素双曲正弦值。 |
|
返回给定轴上的元素数量。 |
|
返回数组的排序副本。 |
|
返回复数数组的排序副本。 |
|
返回 |
|
将数组拆分为子数组。 |
|
计算输入数组的逐元素非负平方根。 |
|
计算输入数组的逐元素平方。 |
|
从数组中删除一个或多个长度为 1 的轴 |
|
沿着新轴连接数组。 |
|
计算给定轴上的标准差。 |
逐元素减去两个数组。 |
|
|
计算数组在给定轴上的元素之和。 |
|
交换数组的两个轴。 |
|
从数组中提取元素。 |
|
从数组中提取元素。 |
|
计算输入中每个元素的三角正切值。 |
|
计算输入的逐元素双曲正切值。 |
|
计算两个 N 维数组的张量点积。 |
|
通过沿指定维度重复 |
|
计算沿给定轴的输入对角线之和。 |
|
使用复合梯形规则沿给定轴进行积分。 |
|
返回 N 维数组的转置版本。 |
|
返回一个在对角线及其下方为 1,其他地方为 0 的数组。 |
|
返回数组的下三角。 |
|
返回大小为 |
|
返回给定数组的下三角索引。 |
|
裁剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角。 |
|
返回大小为 |
|
返回给定数组的上三角索引。 |
|
计算 x1 除以 x2 的逐元素除法 |
|
将输入四舍五入到最接近零的整数。 |
|
在数组上逐元素操作的通用函数。 |
|
|
|
uint16 类型的 JAX 标量构造函数。 |
|
uint32 类型的 JAX 标量构造函数。 |
|
uint64 类型的 JAX 标量构造函数。 |
|
uint8 类型的 JAX 标量构造函数。 |
|
计算两个 1D 数组的并集。 |
|
返回数组中的唯一值。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值,以及计数。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
解包 uint8 数组中的位。 |
|
将扁平索引转换为多维索引。 |
|
沿轴展开数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
展开周期性信号。 |
|
生成范德蒙矩阵。 |
|
计算给定轴上的方差。 |
|
执行两个 1D 向量的共轭乘法。 |
|
执行两个批量向量的共轭乘法。 |
|
批量共轭向量-矩阵乘积。 |
|
定义一个带有广播机制的向量化函数。 |
|
将数组垂直分割成子数组。 |
|
垂直堆叠数组。 |
|
根据条件从两个数组中选择元素。 |
|
创建一个填充零的数组。 |
|
创建一个与给定数组具有相同形状和数据类型的填充零的数组。 |
jax.numpy.fft#
|
计算沿给定轴的一维离散傅里叶变换。 |
|
计算沿给定轴的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算沿给定轴的多维离散傅里叶变换。 |
|
将零频率 FFT 分量移动到频谱中心。 |
|
计算频谱具有埃尔米特对称性的数组的 1-D FFT。 |
|
计算一维逆离散傅里叶变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
|
|
计算频谱具有埃尔米特对称性的数组的 1-D 逆 FFT。 |
|
计算实值一维逆离散傅里叶变换。 |
|
计算实值二维逆离散傅里叶变换。 |
|
计算实值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个 3D 向量的叉积。 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆栈的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算厄米矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算厄米矩阵的特征值。 |
|
返回方阵的逆矩阵。 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆栈的范数。 |
|
将方阵提升到整数次幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆栈。 |
|
有效地计算一系列数组之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的(Moore-Penrose)伪逆。 |
|
计算数组的 QR 分解。 |
|
计算数组的行列式的符号和(自然)对数。 |
|
求解线性方程组。 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个 N 维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或一批向量的向量范数。 |
|
计算两个数组的(批量)向量共轭点积。 |
JAX 数组#
JAX 的 Array
(以及它的别名 jax.numpy.ndarray
)是 JAX 中的核心数组对象:你可以把它看作是 JAX 中与 numpy.ndarray
等价的对象。与 numpy.ndarray
一样,大多数用户不需要手动实例化 Array
对象,而是通过 jax.numpy
函数(例如 array()
、 arange()
、 linspace()
和上面列出的其他函数)来创建它们。
复制和序列化#
JAX Array
对象被设计为在适当情况下与 Python 标准库工具无缝协作。
使用内置的 copy
模块时,当 copy.copy()
或 copy.deepcopy()
遇到 Array
时,它等效于调用 copy()
方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这将在跟踪/JIT 编译的代码中正确工作,尽管在这种情况下,编译器可能会省略复制操作。
当内置的 pickle
模块遇到 Array
时,它将以类似于 pickled numpy.ndarray
对象的方式通过紧凑的位表示进行序列化。当反序列化时,结果将是在 *默认设备* 上的新 Array
对象。这是因为一般来说,序列化和反序列化可能发生在不同的运行时环境中,并且没有通用的方法可以将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在跟踪/JIT 编译的代码中使用 pickle
,它将导致 ConcretizationTypeError
。
Python 数组 API 标准#
注意
在 JAX v0.4.32 之前,你必须 import jax.experimental.array_api
才能为 JAX 数组启用数组 API。在 JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。在 JAX v0.5.0 之后,此导入将引发错误。
从 JAX v0.4.32 开始, jax.Array
和 jax.numpy
与 Python 数组 API 标准兼容。你可以通过 jax.Array.__array_namespace__()
访问数组 API 命名空间。
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在一些地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。
有关更多信息,请参阅 Python 数组 API 标准 文档。