jax.numpy
模块#
使用 jax.lax
中的原语实现 NumPy API。
虽然 JAX 尝试尽可能紧密地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,NumPy 中会原地修改数组的 API 无法在 JAX 中实现。然而,JAX 通常能够提供一种纯函数式的替代 API。例如,JAX 提供了一个纯粹的索引更新函数
x.at[i].set(y)
,而不是原地数组更新 (x[i] = y
)(请参阅ndarray.at
)。类似地,一些 NumPy 函数在可能的情况下会返回数组的视图(例如
transpose()
和reshape()
)。JAX 的这些函数版本将返回副本,但是当使用jax.jit()
编译操作序列时,XLA 通常会优化掉这些副本。NumPy 在将值提升为
float64
类型方面非常积极。JAX 在类型提升方面有时没有那么积极(请参阅 类型提升语义)。一些 NumPy 例程具有数据相关的输出形状(例如
unique()
和nonzero()
)。由于 XLA 编译器要求数组形状在编译时已知,因此此类操作与 JIT 不兼容。因此,JAX 为此类函数添加了一个可选的size
参数,可以在静态指定此参数的情况下将它们与 JIT 一起使用。
几乎所有适用的 NumPy 函数都在 jax.numpy
命名空间中实现;它们在下面列出。
用于索引更新功能的辅助属性。 |
|
|
|
|
计算逐元素的绝对值。 |
|
|
|
|
逐元素地相加两个数组。 |
|
|
测试给定轴上的所有数组元素是否都计算为 True。 |
|
检查两个数组在容差范围内是否逐元素近似相等。 |
|
|
|
|
|
返回复数值数字或数组的角度。 |
|
测试给定轴上的任何数组元素是否计算为 True。 |
|
返回一个新数组,其中值附加到原始数组的末尾。 |
|
沿轴将函数应用于一维数组切片。 |
|
在指定的轴上重复应用函数。 |
|
创建均匀间隔的值的数组。 |
|
计算输入的三角余弦的逐元素逆。 |
|
计算输入的双曲余弦的逐元素逆。 |
|
计算输入的三角正弦的逐元素逆。 |
|
计算输入的双曲正弦的逐元素逆。 |
|
计算输入的三角正切的逐元素逆。 |
|
计算 x1/x2 的反正切,选择正确的象限。 |
|
计算输入的双曲正切的逐元素逆。 |
|
返回数组最大值的索引。 |
|
返回数组最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回对数组排序的索引。 |
|
查找非零数组元素的索引 |
|
|
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示形式。 |
|
将数组拆分为子数组。 |
|
返回数组中数据的字符串表示形式。 |
|
将对象转换为 JAX 数组。 |
|
|
|
|
|
将数组转换为指定的 dtype。 |
|
|
|
|
|
|
|
将输入转换为至少一维的数组。 |
|
将输入转换为至少二维的数组。 |
|
将输入转换为至少三维的数组。 |
|
计算加权平均值。 |
|
返回大小为 M 的 Bartlett 窗口。 |
|
计算整数数组中每个值的出现次数。 |
逐元素计算按位与运算。 |
|
|
计算 |
|
|
|
|
|
|
逐元素计算按位或运算。 |
|
|
|
逐元素计算按位异或运算。 |
|
|
返回大小为 M 的 Blackman 窗口。 |
|
从块列表创建数组。 |
|
|
|
将数组广播到共同的形状。 |
|
将输入形状广播到共同的输出形状。 |
|
将数组广播到指定的形状。 |
沿最后一个轴连接切片、标量和类数组对象。 |
|
|
如果根据转换规则可以进行数据类型之间的转换,则返回 True。 |
|
计算输入数组的逐元素立方根。 |
|
|
|
将输入向上舍入到最接近的整数。 |
所有字符串标量类型的抽象基类。 |
|
|
通过堆叠选择数组的切片来构造数组。 |
|
将数组值剪切到指定范围。 |
|
按列堆叠数组。 |
|
|
|
类型为 complex128 的 JAX 标量构造函数。 |
|
类型为 complex64 的 JAX 标量构造函数。 |
所有由浮点数组成的复数标量类型的抽象基类。 |
|
将复数 dtype 转换为实数 dtype 时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿现有轴连接数组。 |
|
沿现有轴连接数组。 |
|
|
|
返回输入的逐元素复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
计算皮尔逊相关系数。 |
|
两个一维数组的相关性。 |
|
计算输入的每个元素的三角余弦值。 |
|
计算输入的逐元素双曲余弦。 |
|
返回给定轴上非零元素的数量。 |
|
估计加权样本协方差。 |
|
计算两个数组的(批处理)叉积。 |
|
|
|
沿轴的元素的累积乘积。 |
|
沿轴的元素的累积和。 |
|
沿数组轴的累积乘积。 |
|
沿数组轴计算累积和。 |
|
将角度从度数转换为弧度。 |
|
|
|
从数组中删除一个或多个条目。 |
|
返回指定的对角线或构造一个对角线数组。 |
|
返回用于访问多维数组主对角线的索引。 |
|
返回用于访问给定数组主对角线的索引。 |
|
返回一个 2-D 数组,其对角线上铺设了展平的输入数组。 |
|
返回数组的指定对角线。 |
|
计算沿给定轴的数组元素之间的 n 阶差分。 |
|
将数组转换为 bin 索引。 |
|
|
|
按元素方式计算 x1 除以 x2 的整数商和余数 |
|
计算两个数组的点积。 |
|
|
|
沿深度方向将数组拆分为子数组。 |
|
沿深度方向堆叠数组。 |
|
创建数据类型对象。 |
|
计算展平数组的元素的差值。 |
|
爱因斯坦求和 |
|
计算最优收缩路径,而不计算 einsum。 |
|
创建一个空数组。 |
|
创建一个与数组具有相同形状和数据类型的空数组。 |
|
返回 |
|
计算输入的按元素指数。 |
|
计算输入的按元素 2 为底的指数。 |
|
将长度为 1 的维度插入数组 |
|
计算输入中每个元素的 |
|
返回满足条件的数组的元素。 |
|
创建方形或矩形单位矩阵 |
|
计算实值输入的按元素绝对值。 |
|
返回一个数组的副本,其中对角线被覆盖。 |
|
浮点类型的机器限制。 |
|
将输入四舍五入到最接近于零的整数。 |
|
返回展平数组中非零元素的索引 |
|
所有没有预定义长度的标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴 1 反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
|
|
|
计算 |
|
float16 类型的 JAX 标量构造函数。 |
|
float32 类型的 JAX 标量构造函数。 |
|
float64 类型的 JAX 标量构造函数。 |
|
所有浮点标量类型的抽象基类。 |
|
将输入四舍五入到最接近的向下整数。 |
|
按元素方式计算 x1 除以 x2 的向下除法 |
|
返回输入数组的按元素最大值。 |
|
返回输入数组的按元素最小值。 |
|
计算按元素浮点模运算。 |
|
将浮点值拆分为尾数和 2 的指数。 |
|
将缓冲区转换为 1-D JAX 数组。 |
|
jnp.fromfile 的未实现 JAX 封装。 |
|
从应用于索引的函数创建数组。 |
|
jnp.fromiter 的未实现 JAX 封装。 |
|
从任意 JAX 兼容的标量函数创建 JAX ufunc。 |
|
将文本字符串转换为 1 维 JAX 数组。 |
|
通过 DLPack 构建 JAX 数组。 |
|
创建一个填充指定值的数组。 |
|
创建一个具有与数组相同形状和 dtype 的、填充指定值的数组。 |
|
计算两个数组的最大公约数。 |
|
numpy 标量类型的基类。 |
|
生成几何间隔的值。 |
返回当前打印选项。 |
|
|
计算采样函数的数值梯度。 |
|
返回 |
|
返回 |
|
返回大小为 M 的汉明窗。 |
|
返回大小为 M 的汉宁窗。 |
|
计算赫维赛德阶跃函数。 |
|
计算一维直方图。 |
|
计算直方图的 bin 边缘。 |
|
计算二维直方图。 |
|
计算 N 维直方图。 |
|
将数组水平分割为子数组。 |
|
水平堆叠数组。 |
|
返回直角三角形给定边长的逐元素斜边。 |
|
计算第一类零阶修正贝塞尔函数。 |
|
创建方形单位矩阵。 |
|
|
|
返回复数参数的逐元素虚部。 |
一种构建数组索引元组的更友好的方法。 |
|
|
生成网格索引数组。 |
|
所有数值标量类型的抽象基类,其范围内的值具有(可能)不精确的表示,例如浮点数。 |
|
计算两个数组的内积。 |
|
在指定索引处将条目插入到数组中。 |
|
|
|
类型为 int16 的 JAX 标量构造函数。 |
|
类型为 int32 的 JAX 标量构造函数。 |
|
类型为 int64 的 JAX 标量构造函数。 |
|
类型为 int8 的 JAX 标量构造函数。 |
|
所有整数标量类型的抽象基类。 |
|
一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
计算输入的按位反转。 |
|
检查两个数组的元素是否在容差范围内近似相等。 |
|
返回显示输入是否为复数的布尔数组。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否为指定的类型。 |
|
返回一个布尔数组,指示输入的每个元素是否为有限值。 |
|
确定 |
|
返回一个布尔数组,指示输入的每个元素是否为无限值。 |
|
返回一个布尔数组,指示输入的每个元素是否为 |
|
返回一个布尔数组,指示输入的每个元素是否为负无穷大。 |
|
返回一个布尔数组,指示输入中的每个元素是否为正无穷大。 |
|
返回一个布尔数组,显示输入是否为实数。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果输入是标量,则返回 True。 |
|
如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。 |
|
检查对象是否可以迭代。 |
|
从 N 个一维序列返回一个多维网格(开放网格)。 |
|
返回大小为 M 的 Kaiser 窗口。 |
|
计算两个输入数组的克罗内克积。 |
|
计算两个数组的最小公倍数。 |
|
计算 x1 * 2 ** x2 |
|
将 |
|
返回 |
|
返回 |
|
按字典顺序对键序列进行排序。 |
|
返回间隔内的均匀间隔的数字。 |
|
从 npy 文件加载 JAX 数组。 |
|
计算输入的按元素自然对数。 |
|
按元素计算 x 的以 10 为底的对数 |
|
按元素计算 1 加输入的对数, |
|
按元素计算 |
计算 |
|
以 2 为底的输入指数和的对数,避免溢出。 |
|
按元素计算逻辑与运算。 |
|
|
按元素计算 NOT bool(x)。 |
按元素计算逻辑或运算。 |
|
按元素计算逻辑异或运算。 |
|
|
生成对数间隔的值。 |
|
返回 (n, n) 数组的掩码的索引。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
批量矩阵向量积。 |
|
返回沿给定轴的数组元素的最大值。 |
|
返回输入数组的按元素最大值。 |
|
返回沿给定轴的数组元素的平均值。 |
|
返回沿给定轴的数组元素的中位数。 |
|
从 N 个一维向量构造 N 维网格数组。 |
返回密集的 多维 “meshgrid”。 |
|
|
返回沿给定轴的数组元素的最小值。 |
|
返回输入数组的按元素最小值。 |
|
|
|
返回输入数组的按元素分数部分和整数部分。 |
|
将数组轴移动到新位置 |
按元素方式将两个数组相乘。 |
|
|
替换数组中的 NaN 和无穷大条目。 |
|
返回数组最大值的索引,忽略 NaN。 |
|
返回数组最小值的索引,忽略 NaN。 |
|
沿轴的元素的累积积,忽略 NaN 值。 |
|
沿轴的元素的累积和,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的最大值,忽略 NaN。 |
|
返回沿给定轴的数组元素的平均值,忽略 NaN。 |
|
返回沿给定轴的数组元素的中位数,忽略 NaN。 |
|
返回给定轴上数组元素的最小值,忽略 NaN 值。 |
|
计算沿指定轴的数据的百分位数,忽略 NaN 值。 |
|
返回给定轴上数组元素的乘积,忽略 NaN 值。 |
|
计算沿指定轴的数据的分位数,忽略 NaN 值。 |
|
计算沿给定轴的标准差,忽略 NaN 值。 |
|
返回给定轴上数组元素的总和,忽略 NaN 值。 |
|
计算沿给定轴的数组元素的方差,忽略 NaN 值。 |
|
|
|
返回数组的维度数。 |
返回输入的逐元素负值。 |
|
|
返回元素方向上 |
|
返回数组中非零元素的索引。 |
|
返回 |
|
所有数值标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“meshgrid”。 |
|
|
创建一个充满 1 的数组。 |
|
创建一个与数组具有相同形状和 dtype 的全 1 数组。 |
|
计算两个数组的外积。 |
|
将位数组打包成 uint8 数组。 |
|
向数组添加填充。 |
|
返回数组的部分排序副本。 |
|
计算沿指定轴的数据的百分位数。 |
|
置换数组的轴/维度。 |
|
计算在域中分段定义的函数。 |
|
基于掩码更新数组元素。 |
|
返回给定根序列的多项式的系数。 |
|
返回两个多项式的和。 |
|
返回指定阶数的多项式导数的系数。 |
|
返回多项式除法的商和余数。 |
|
数据最小二乘多项式拟合。 |
|
返回指定阶数的多项式积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
计算多项式在特定值上的值。 |
|
返回输入的逐元素正值。 |
|
|
|
计算 |
|
用于设置打印选项的上下文管理器。 |
|
返回给定轴上数组元素的乘积。 |
|
返回二元运算应将其参数强制转换的类型。 |
|
返回沿给定轴的峰峰值范围。 |
|
将元素放入给定索引的数组中。 |
|
通过匹配 1 维索引和数据切片,将值放入目标数组中。 |
|
计算沿指定轴的数据分位数。 |
沿第一个轴连接切片、标量和类数组对象。 |
|
|
将角度从弧度转换为度。 |
|
|
|
将数组展平为一维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的逐元素实部。 |
|
计算输入的逐元素倒数。 |
|
返回除法的逐元素余数。 |
|
从重复的元素构造一个数组。 |
|
返回数组的重塑副本。 |
|
返回具有指定形状的新数组。 |
|
返回将 JAX 提升规则应用于输入的结果。 |
|
将 |
|
将 x 的元素四舍五入到最接近的整数。 |
|
沿指定轴滚动数组的元素。 |
|
将指定的轴滚动到给定的位置。 |
|
返回给定系数 |
|
在 axes 指定的平面中,将数组逆时针旋转 90 度。 |
|
将输入均匀地舍入到给定的小数位数。 |
一种构建数组索引元组的更友好的方法。 |
|
|
以 NumPy |
|
以未压缩的 |
|
在排序数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
设置打印选项。 |
|
计算两个 1D 数组的集合差。 |
|
计算两个数组中元素的集合式异或。 |
|
返回数组的形状。 |
|
返回输入的逐元素符号指示。 |
|
返回数组元素的符号位。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入的每个元素的三角正弦值。 |
|
计算归一化 sinc 函数。 |
|
|
|
计算输入的逐元素双曲正弦值。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的排序副本。 |
|
返回复数数组的排序副本。 |
|
返回 |
|
将数组拆分为子数组。 |
|
计算输入数组的逐元素非负平方根。 |
|
计算输入数组的逐元素平方。 |
|
从数组中删除一个或多个长度为 1 的轴。 |
|
沿新轴连接数组。 |
|
计算沿给定轴的标准差。 |
逐元素减去两个数组。 |
|
|
求数组沿给定轴的元素之和。 |
|
交换数组的两个轴。 |
|
从数组中提取元素。 |
|
从数组中提取元素。 |
|
计算输入中每个元素的三角正切值。 |
|
计算输入的逐元素双曲正切值。 |
|
计算两个 N 维数组的张量点积。 |
|
通过沿指定维度重复 |
|
计算输入沿给定轴的对角线之和。 |
|
使用复合梯形规则沿给定轴积分。 |
|
返回 N 维数组的转置版本。 |
|
返回一个在对角线及以下为 1,其他地方为 0 的数组。 |
|
返回数组的下三角部分。 |
|
返回大小为 |
|
返回给定数组的下三角部分的索引。 |
|
修剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角部分。 |
|
返回大小为 |
|
返回给定数组的上三角部分的索引。 |
|
逐元素计算 x1 除以 x2 的结果 |
|
将输入四舍五入到最接近于零的整数。 |
|
通用函数,对数组逐元素执行操作。 |
|
|
|
uint16 类型的 JAX 标量构造函数。 |
|
uint32 类型的 JAX 标量构造函数。 |
|
uint64 类型的 JAX 标量构造函数。 |
|
uint8 类型的 JAX 标量构造函数。 |
|
计算两个一维数组的并集。 |
|
返回数组中的唯一值。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值以及计数。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
返回 x 中的唯一值,以及索引、逆索引和计数。 |
|
解包 uint8 数组中的位。 |
|
将扁平索引转换为多维索引。 |
|
沿轴解堆叠数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
展开周期信号。 |
|
生成范德蒙矩阵。 |
|
计算沿给定轴的方差。 |
|
执行两个一维向量的共轭乘法。 |
|
执行两个批量向量的共轭乘法。 |
|
批量共轭向量-矩阵乘积。 |
|
定义一个具有广播功能的向量化函数。 |
|
将数组垂直拆分为子数组。 |
|
垂直堆叠数组。 |
|
根据条件从两个数组中选择元素。 |
|
创建一个充满零的数组。 |
|
创建一个与数组具有相同形状和 dtype 的充满零的数组。 |
jax.numpy.fft#
|
沿给定轴计算一维离散傅里叶变换。 |
|
沿给定轴计算二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
沿给定轴计算多维离散傅里叶变换。 |
|
将零频率 fft 分量移动到频谱的中心。 |
|
计算频谱具有埃尔米特对称性的数组的一维 FFT。 |
|
计算一维离散傅里叶逆变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
|
|
计算频谱具有厄米对称性的数组的一维逆 FFT。 |
|
计算实值一维逆离散傅里叶变换。 |
|
计算实值二维逆离散傅里叶变换。 |
|
计算实值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的采样频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个 3D 向量的叉积。 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆叠的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算埃尔米特矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算埃尔米特矩阵的特征值。 |
|
返回方阵的逆矩阵。 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆叠的范数。 |
|
将方阵提升到整数幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆叠。 |
|
高效地计算一系列数组之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的(Moore-Penrose)伪逆。 |
|
计算数组的 QR 分解。 |
|
计算数组的行列式的符号和(自然)对数。 |
|
求解线性方程组。 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个 N 维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或批量向量的向量范数。 |
|
计算两个数组的(批量)向量共轭点积。 |
JAX 数组#
JAX Array
(以及它的别名 jax.numpy.ndarray
)是 JAX 中的核心数组对象:你可以把它看作是 JAX 中与 numpy.ndarray
等价的对象。和 numpy.ndarray
一样,大多数用户不需要手动实例化 Array
对象,而是通过 jax.numpy
函数(如 array()
、arange()
、linspace()
以及上面列出的其他函数)来创建它们。
复制和序列化#
JAX Array
对象被设计为在适当的情况下与 Python 标准库工具无缝协作。
使用内置的 copy
模块时,当 copy.copy()
或 copy.deepcopy()
遇到 Array
时,它等效于调用 copy()
方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这在跟踪/JIT 编译的代码中可以正常工作,尽管在这种情况下,编译器可能会省略复制操作。
当内置的 pickle
模块遇到 Array
时,它将通过紧凑的位表示进行序列化,其方式类似于 pickled 的 numpy.ndarray
对象。当反序列化时,结果将是在默认设备上 的新的 Array
对象。这是因为通常,序列化和反序列化可能发生在不同的运行时环境中,并且没有通用的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在跟踪/JIT 编译的代码中使用 pickle
,则会导致 ConcretizationTypeError
。
Python 数组 API 标准#
注意
在 JAX v0.4.32 之前,你必须 import jax.experimental.array_api
才能为 JAX 数组启用数组 API。在 JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。
从 JAX v0.4.32 开始,jax.Array
和 jax.numpy
与 Python 数组 API 标准兼容。你可以通过 jax.Array.__array_namespace__()
访问数组 API 命名空间。
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在一些地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持原地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。
有关更多信息,请参阅 Python 数组 API 标准 文档。