jax.numpy
模块#
使用 jax.lax
中的原语实现 NumPy API。
虽然 JAX 尽量尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,因此无法在 JAX 中实现就地修改数组的 NumPy API。但是,JAX 通常能够提供纯函数式替代 API。例如,代替就地数组更新 (
x[i] = y
),JAX 提供了一个纯索引更新函数的替代方案x.at[i].set(y)
(参见ndarray.at
)。同样地,一些 NumPy 函数通常会在可能的情况下返回数组的视图(例如
transpose()
和reshape()
)。JAX 版本的这些函数将改为返回副本,尽管当使用jax.jit()
编译操作序列时,XLA 通常会优化掉这些副本。NumPy 在将值提升到
float64
类型方面非常积极。JAX 有时在类型提升方面不那么积极(参见 类型提升语义)。一些 NumPy 例程具有数据相关的输出形状(例如
unique()
和nonzero()
)。由于 XLA 编译器要求在编译时知道数组形状,因此此类操作与 JIT 不兼容。出于这个原因,JAX 为此类函数添加了一个可选的size
参数,该参数可以在静态中指定,以便与 JIT 一起使用。
几乎所有适用的 NumPy 函数都实现在 jax.numpy
命名空间中;它们列在下面。
索引更新功能的辅助属性。 |
|
|
|
|
逐元素计算绝对值。 |
|
|
|
|
逐元素相加两个数组。 |
|
|
测试给定轴上所有数组元素是否都计算为 True。 |
|
检查两个数组在容差范围内是否逐元素近似相等。 |
|
|
|
|
|
返回复数或复数数组的角度。 |
|
测试给定轴上任何数组元素是否都计算为 True。 |
|
返回一个新数组,其值附加到原始数组的末尾。 |
|
将函数应用于给定轴上的 1-D 切片。 |
|
在多个轴上重复应用函数。 |
|
创建一个等间距值的数组。 |
|
三角反余弦,逐元素计算。 |
|
反双曲余弦,逐元素计算。 |
|
反正弦,逐元素计算。 |
|
反双曲正弦,逐元素计算。 |
|
三角反正切,逐元素计算。 |
|
|
|
反双曲正切,逐元素计算。 |
|
返回沿轴的最大值的索引。 |
|
返回沿轴的最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回排序数组的索引。 |
|
查找非零数组元素的索引 |
|
|
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示形式。 |
|
将数组拆分为子数组。 |
|
返回数组中数据的字符串表示形式。 |
|
将对象转换为 JAX 数组。 |
|
|
|
|
|
将数组转换为指定的 dtype。 |
|
|
|
|
|
|
将输入转换为至少具有 1 个维度的数组。 |
|
将输入转换为至少具有 2 个维度的数组。 |
|
将输入转换为至少具有 3 个维度的数组。 |
|
|
沿指定的轴计算加权平均值。 |
|
返回 Bartlett 窗。 |
|
计算整数数组中每个值的出现次数。 |
逐元素计算按位 AND 操作。 |
|
|
计算 |
|
计算按位取反,或按位非,逐元素操作。 |
|
将整数的位向左移动。 |
|
计算按位取反,或按位非,逐元素操作。 |
逐元素计算按位或运算。 |
|
|
|
逐元素计算按位异或运算。 |
|
|
返回Blackman窗。 |
|
从嵌套的块列表中组装一个nd数组。 |
|
|
|
将数组广播到一个共同的形状。 |
将输入形状广播到一个共同的输出形状。 |
|
|
将数组广播到指定的形状。 |
沿着最后一个轴连接切片、标量和类数组对象。 |
|
|
如果根据转换规则可以发生数据类型之间的转换,则返回True。 |
|
返回数组的立方根,逐元素操作。 |
|
|
|
向上舍入输入到最接近的整数。 |
所有字符字符串标量类型的抽象基类。 |
|
|
根据索引数组和要从中选择的数组列表构造一个数组。 |
|
将数组值剪辑到指定的范围。 |
|
按列堆叠数组。 |
|
|
|
|
|
|
所有由浮点数组成的复数标量类型的抽象基类。 |
|
将复数dtype转换为实数dtype时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿着现有轴连接数组。 |
|
沿着现有轴连接数组。 |
|
返回复共轭,逐元素操作。 |
|
返回复共轭,逐元素操作。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
返回皮尔逊积矩相关系数。 |
|
两个一维数组的相关性。 |
|
计算输入每个元素的三角余弦。 |
|
双曲余弦,逐元素操作。 |
|
返回沿给定轴非零元素的数量。 |
|
根据数据和权重估计协方差矩阵。 |
|
返回两个(向量数组)的叉积。 |
|
|
|
沿轴计算元素的累积积。 |
|
沿轴计算元素的累积和。 |
|
沿数组的轴计算累积和。 |
|
将角度从度转换为弧度。 |
|
将角度从弧度转换为度。 |
|
从数组中删除条目或条目。 |
|
返回指定的对角线或构造对角线数组。 |
|
返回用于访问多维数组主对角线的索引。 |
|
返回用于访问给定数组主对角线的索引。 |
|
返回一个二维数组,其中扁平化的输入数组位于对角线上。 |
|
返回数组的指定对角线。 |
|
计算沿给定轴的数组元素之间的n阶差分。 |
|
返回输入数组中每个值所属的bin的索引。 |
|
|
|
逐元素计算x1除以x2的整数商和余数 |
|
计算两个数组的点积。 |
|
|
|
沿深度方向将数组分割成子数组。 |
|
沿深度方向堆叠数组。 |
|
创建一个数据类型对象。 |
|
计算扁平化数组元素的差值。 |
|
爱因斯坦求和 |
在不评估 einsum 的情况下评估最佳收缩路径。 |
|
|
创建一个空的数组。 |
|
创建一个与数组具有相同形状和数据类型的空数组。 |
|
逐元素返回 (x1 == x2)。 |
|
计算输入的逐元素指数。 |
|
计算输入的逐元素以 2 为底的指数。 |
|
在数组中插入长度为 1 的维度 |
|
计算输入每个元素的 |
|
返回满足条件的数组元素。 |
|
创建一个方阵或矩形单位矩阵 |
|
计算实值输入的逐元素绝对值。 |
|
返回一个数组的副本,其中对角线被覆盖。 |
|
浮点类型的机器限制。 |
|
将输入四舍五入到最接近零的整数。 |
|
返回扁平化数组中非零元素的索引 |
|
所有没有预定义长度的标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴 1 反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
|
|
|
计算 |
|
|
|
|
|
|
|
所有浮点标量类型的抽象基类。 |
|
将输入向下舍入到最接近的整数。 |
|
逐元素计算 x1 除以 x2 的向下取整除法 |
|
返回输入数组的逐元素最大值。 |
|
返回输入数组的逐元素最小值。 |
|
返回除法的逐元素余数。 |
|
将 x 的元素分解成尾数和以 2 为底的指数。 |
|
将缓冲区转换为一维 JAX 数组。 |
|
jnp.fromfile 的未实现 JAX 包装器。 |
|
通过在每个坐标上执行函数来构造数组。 |
|
jnp.fromiter 的未实现 JAX 包装器。 |
|
从任意兼容 JAX 的标量函数创建 JAX ufunc。 |
|
将文本字符串转换为一维 JAX 数组。 |
|
通过 DLPack 构造 JAX 数组。 |
|
创建一个充满指定值的数组。 |
|
创建一个充满指定值的数组,其形状和数据类型与数组相同。 |
|
计算两个数组的最大公约数。 |
|
numpy 标量类型的基类。 |
|
返回对数刻度上均匀间隔的数字(几何级数)。 |
返回当前打印选项。 |
|
|
返回 N 维数组的梯度。 |
|
返回 |
|
返回 |
|
返回汉明窗。 |
|
返回汉宁窗。 |
|
计算海维赛德阶跃函数。 |
|
计算数据集的直方图。 |
|
仅计算 histogram 使用的 bin 边缘的函数 |
|
计算两个数据样本的二维直方图。 |
|
计算某些数据的多分量直方图。 |
|
水平分割数组为子数组。 |
|
水平堆叠数组。 |
|
给定直角三角形的“直角边”,返回其斜边。 |
第一类修正贝塞尔函数,阶数为0。 |
|
|
创建一个正方形单位矩阵 |
|
|
|
返回复数参数的虚部。 |
一种更友好的构建数组索引元组的方式。 |
|
|
返回表示网格索引的数组。 |
|
所有数值标量类型的抽象基类,其值范围(可能)以不精确的方式表示,例如浮点数。 |
|
计算两个数组的内积。 |
|
在给定轴上的给定索引之前插入值。 |
|
|
|
|
|
|
|
|
|
|
|
所有整数标量类型的抽象基类。 |
|
单维线性插值,适用于单调递增的样本点。 |
|
计算两个一维数组的交集。 |
|
计算按位取反,或按位非,逐元素操作。 |
|
检查两个数组的元素在容差范围内是否近似相等。 |
|
返回布尔数组,显示输入在何处为复数。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否为指定类型。 |
|
逐元素测试有限性(不是无穷大,也不是非数字)。 |
|
确定 |
|
逐元素测试正无穷或负无穷。 |
|
逐元素测试 NaN 并将结果作为布尔数组返回。 |
|
逐元素测试负无穷,将结果作为布尔数组返回。 |
|
逐元素测试负无穷,将结果作为布尔数组返回。 |
|
返回布尔数组,显示输入在何处为实数。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果element的类型是标量类型,则返回 True。 |
|
如果第一个参数是类型层次结构中较低/相等的类型代码,则返回 True。 |
|
检查对象是否可以迭代。 |
|
从 N 个一维序列返回一个多维网格(开放网格)。 |
|
返回凯泽窗。 |
|
计算两个输入数组的克罗内克积。 |
|
计算两个数组的最小公倍数。 |
|
逐元素返回 x1 * 2**x2。 |
|
将整数的位向左移动。 |
|
返回 |
|
返回 |
|
使用一系列键执行间接稳定排序。 |
|
在区间内返回等间距的数字。 |
|
从 |
|
计算输入的逐元素自然对数。 |
|
逐元素计算 x 的以 10 为底的对数 |
|
计算一加输入的逐元素对数, |
|
逐元素计算 |
|
计算 |
|
以 2 为底计算输入指数和的对数,避免溢出。 |
逐元素计算逻辑 AND 运算。 |
|
|
逐元素计算 NOT x 的真值。 |
逐元素计算逻辑 OR 运算。 |
|
逐元素计算逻辑异或运算。 |
|
|
返回对数刻度上均匀分布的数字。 |
|
给定一个掩码函数,返回访问 (n, n) 数组的索引。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
沿给定轴返回数组元素的最大值。 |
|
返回输入数组的逐元素最大值。 |
|
沿给定轴返回数组元素的平均值。 |
|
沿给定轴返回数组元素的中位数。 |
|
从坐标向量返回坐标矩阵的元组。 |
返回密集的多维“网格”。 |
|
|
沿给定轴返回数组元素的最小值。 |
|
返回输入数组的逐元素最小值。 |
|
返回除法的逐元素余数。 |
|
逐元素返回数组的小数部分和整数部分。 |
|
将数组轴移动到新位置 |
逐元素将两个数组相乘。 |
|
|
将 NaN 替换为零,将无穷大替换为大的有限数字(默认 |
|
返回指定轴上最大值的索引,忽略 |
|
返回指定轴上最小值的索引,忽略 |
|
沿轴计算元素的累积积,忽略 NaN 值。 |
|
沿轴计算元素的累积和,忽略 NaN 值。 |
|
沿给定轴返回数组元素的最大值,忽略 NaN。 |
|
沿给定轴返回数组元素的平均值,忽略 NaN。 |
|
沿给定轴返回数组元素的中位数,忽略 NaN。 |
|
沿给定轴返回数组元素的最小值,忽略 NaN。 |
|
计算数据沿指定轴的百分位数,忽略 NaN 值。 |
|
沿给定轴返回数组元素的乘积,忽略 NaN。 |
|
计算数据沿指定轴的分位数,忽略 NaN。 |
|
沿给定轴计算标准差,忽略 NaN。 |
|
沿给定轴返回数组元素的和,忽略 NaN。 |
|
沿给定轴计算数组元素的方差,忽略 NaN。 |
|
|
|
返回数组的维度数。 |
|
返回输入的逐元素负值。 |
|
返回 |
|
返回数组中非零元素的索引。 |
|
返回 (x1 != x2) 的逐元素结果。 |
|
所有数值标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“网格”。 |
|
|
创建一个全为一的数组。 |
|
创建一个与数组具有相同形状和数据类型的全为一的数组。 |
|
计算两个数组的外积。 |
|
将二值数组的元素打包到 uint8 数组的位中。 |
|
填充数组。 |
|
返回数组的部分排序副本。 |
|
计算数据沿指定轴的百分位数。 |
|
置换数组的轴/维度。 |
|
评估跨域定义的分段函数。 |
|
基于掩码更新数组元素。 |
|
返回给定根序列的多项式的系数。 |
|
返回两个多项式的和。 |
|
返回指定阶多项式导数的系数。 |
|
返回多项式除法的商和余数。 |
|
对数据进行最小二乘多项式拟合。 |
|
返回指定阶多项式积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
在特定值处计算多项式的值。 |
|
返回输入的逐元素正值。 |
|
|
|
计算 |
|
用于设置打印选项的上下文管理器。 |
|
返回沿给定轴的数组元素的乘积。 |
|
返回二元运算应将其参数转换为的类型。 |
|
返回沿给定轴的峰峰值范围。 |
|
将元素放入给定索引处的数组中。 |
|
沿指定轴计算数据的分位数。 |
沿第一个轴连接切片、标量和类似数组的对象。 |
|
|
将角度从弧度转换为度。 |
|
将角度从度转换为弧度。 |
|
将数组展平为一维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的实部。 |
|
返回参数的倒数,逐元素计算。 |
|
返回除法的逐元素余数。 |
|
从重复的元素构造一个数组。 |
|
返回数组的重新整形副本。 |
|
返回具有指定形状的新数组。 |
|
返回应用 NumPy |
|
将 |
|
将 x 的元素四舍五入到最接近的整数 |
|
沿指定轴滚动数组的元素。 |
|
将指定的轴滚动到给定位置。 |
|
给定系数 |
|
将数组在由轴指定的平面中逆时针旋转 90 度。 |
|
将输入均匀四舍五入到给定的十进制位数。 |
|
将输入均匀四舍五入到给定的十进制位数。 |
一种更友好的构建数组索引元组的方式。 |
|
|
将数组保存到 NumPy |
|
将多个数组保存到单个未压缩的 |
|
在已排序的数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
设置打印选项。 |
|
计算两个一维数组的集合差。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入的符号的逐元素指示。 |
|
在符号位设置(小于零)的位置返回逐元素的 True。 |
所有带符号整数标量类型的抽象基类。 |
|
|
计算输入的每个元素的三角正弦。 |
|
返回归一化的 sinc 函数。 |
|
|
|
双曲正弦,逐元素计算。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的已排序副本。 |
|
根据实部,然后根据虚部对复杂数组进行排序。 |
|
将数组拆分为子数组。 |
|
逐元素返回数组的非负平方根。 |
|
返回输入的逐元素平方。 |
|
从数组中删除一个或多个长度为 1 的轴。 |
|
沿着新轴连接数组。 |
|
计算沿给定轴的标准差。 |
|
逐元素减去参数。 |
|
沿给定轴计算数组元素的和。 |
|
交换数组的两个轴。 |
|
从数组中获取元素。 |
|
从数组中获取元素。 |
|
计算输入每个元素的三角正切。 |
|
逐元素计算双曲正切。 |
|
计算两个 N 维数组的张量点积。 |
|
通过重复 A reps 指定的次数来构造一个数组。 |
|
返回数组对角线元素的和。 |
|
使用复合梯形规则沿给定轴积分。 |
|
返回 N 维数组的转置版本。 |
|
返回一个数组,在对角线及其下方为 1,其他位置为 0。 |
|
返回数组的下三角部分。 |
|
返回大小为 |
|
返回给定数组下三角部分的索引。 |
|
修剪输入数组开头和/或结尾的零。 |
|
返回数组的上三角部分。 |
|
返回大小为 |
|
返回给定数组上三角部分的索引。 |
|
逐元素计算 x1 除以 x2 的结果。 |
|
将输入四舍五入到最接近零的整数。 |
|
通用函数,对数组逐元素进行运算。 |
|
|
|
|
|
|
|
|
|
|
|
计算两个一维数组的并集。 |
|
返回数组中唯一的值。 |
|
返回 x 中唯一的值,以及索引、反向索引和计数。 |
|
返回 x 中唯一的值以及计数。 |
|
返回 x 中唯一的值,以及索引、反向索引和计数。 |
|
返回 x 中唯一的值,以及索引、反向索引和计数。 |
|
将 uint8 数组的元素解包到二进制值的输出数组中。 |
|
将扁平索引转换为多维索引。 |
|
沿轴解开数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
通过取相对于周期的较大增量的补码来解开。 |
|
生成范德蒙矩阵。 |
|
计算沿给定轴的方差。 |
|
执行两个一维向量的共轭乘法。 |
|
执行两个批处理向量的共轭乘法。 |
|
定义具有广播功能的向量化函数。 |
|
垂直拆分数组为子数组。 |
|
垂直堆叠数组。 |
|
根据条件从两个数组中选择元素。 |
|
创建一个全零数组。 |
|
创建一个与给定数组形状和数据类型相同的全零数组。 |
jax.numpy.fft#
|
沿给定轴计算一维离散傅里叶变换。 |
|
沿给定轴计算二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
沿给定轴计算多维离散傅里叶变换。 |
|
将零频率分量移到频谱的中心。 |
|
计算一个数组的一维 FFT,该数组的频谱具有厄米对称性。 |
|
计算一维逆离散傅里叶变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
是fftshift的逆运算。 |
|
计算一个数组的一维逆 FFT,该数组的频谱具有厄米对称性。 |
|
计算实值一维逆离散傅里叶变换。 |
|
计算实值二维逆离散傅里叶变换。 |
|
计算实值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的Cholesky分解。 |
|
计算矩阵的条件数。 |
|
计算两个3D向量的叉积。 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆栈的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算厄米矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算厄米矩阵的特征值。 |
|
返回方阵的逆矩阵。 |
|
返回线性方程组的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆栈的范数。 |
|
将方阵提升到整数次幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆栈。 |
|
有效地计算一系列数组之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的(Moore-Penrose)伪逆。 |
|
计算数组的QR分解。 |
|
计算数组的行列式的符号和(自然)对数。 |
|
解线性方程组。 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个 N 维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或向量批次的向量范数。 |
|
计算两个数组的(批处理)向量共轭点积。 |
JAX 数组#
JAX 的 Array
(及其别名 jax.numpy.ndarray
)是 JAX 中的核心数组对象:您可以将其视为 JAX 中等效于 numpy.ndarray
的对象。与 numpy.ndarray
一样,大多数用户无需手动实例化 Array
对象,而是通过 jax.numpy
中的函数(如 array()
、arange()
、linspace()
等,如上所述)来创建它们。
复制和序列化#
JAX Array
对象旨在在适当的情况下与 Python 标准库工具无缝协作。
使用内置的 copy
模块,当 copy.copy()
或 copy.deepcopy()
遇到 Array
时,它等效于调用 copy()
方法,这将在与原始数组相同的设备上创建缓冲区的副本。这将在跟踪/JIT 编译的代码中正常工作,尽管在这种情况下,编译器可能会省略复制操作。
当内置的 pickle
模块遇到 Array
时,它将通过类似于腌制 numpy.ndarray
对象的紧凑位表示形式进行序列化。取消腌制后,结果将是一个新的 Array
对象,**位于默认设备上**。这是因为通常,腌制和取消腌制可能发生在不同的运行时环境中,并且没有普遍的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在跟踪/JIT 编译的代码中使用 pickle
,则会导致 ConcretizationTypeError
。
Python 数组 API 标准#
注意
在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api
才能为 JAX 数组启用数组 API。在 JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。
从 JAX v0.4.32 开始,jax.Array
和 jax.numpy
与 Python 数组 API 标准 兼容。您可以通过 jax.Array.__array_namespace__()
访问数组 API 命名空间。
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在少数几个地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块得到解决。
有关更多信息,请参阅 Python 数组 API 标准 文档。