jax.numpy 模块#

使用 jax.lax 中的原语实现 NumPy API。

虽然 JAX 尽量尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。

  • 值得注意的是,由于 JAX 数组是不可变的,因此无法在 JAX 中实现就地修改数组的 NumPy API。但是,JAX 通常能够提供纯函数式替代 API。例如,代替就地数组更新 (x[i] = y),JAX 提供了一个纯索引更新函数的替代方案 x.at[i].set(y)(参见 ndarray.at)。

  • 同样地,一些 NumPy 函数通常会在可能的情况下返回数组的视图(例如 transpose()reshape())。JAX 版本的这些函数将改为返回副本,尽管当使用 jax.jit() 编译操作序列时,XLA 通常会优化掉这些副本。

  • NumPy 在将值提升到 float64 类型方面非常积极。JAX 有时在类型提升方面不那么积极(参见 类型提升语义)。

  • 一些 NumPy 例程具有数据相关的输出形状(例如 unique()nonzero())。由于 XLA 编译器要求在编译时知道数组形状,因此此类操作与 JIT 不兼容。出于这个原因,JAX 为此类函数添加了一个可选的 size 参数,该参数可以在静态中指定,以便与 JIT 一起使用。

几乎所有适用的 NumPy 函数都实现在 jax.numpy 命名空间中;它们列在下面。

ndarray.at

索引更新功能的辅助属性。

abs(x, /)

jax.numpy.absolute() 的别名。

absolute(x, /)

逐元素计算绝对值。

acos(x, /)

jax.numpy.arccos() 的别名

acosh(x, /)

jax.numpy.arccosh() 的别名

add

逐元素相加两个数组。

all(a[, axis, out, keepdims, where])

测试给定轴上所有数组元素是否都计算为 True。

allclose(a, b[, rtol, atol, equal_nan])

检查两个数组在容差范围内是否逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

jax.numpy.max() 的别名。

amin(a[, axis, out, keepdims, initial, where])

jax.numpy.min() 的别名。

angle(z[, deg])

返回复数或复数数组的角度。

any(a[, axis, out, keepdims, where])

测试给定轴上任何数组元素是否都计算为 True。

append(arr, values[, axis])

返回一个新数组,其值附加到原始数组的末尾。

apply_along_axis(func1d, axis, arr, *args, ...)

将函数应用于给定轴上的 1-D 切片。

apply_over_axes(func, a, axes)

在多个轴上重复应用函数。

arange(start[, stop, step, dtype, device])

创建一个等间距值的数组。

arccos(x, /)

三角反余弦,逐元素计算。

arccosh(x, /)

反双曲余弦,逐元素计算。

arcsin(x, /)

反正弦,逐元素计算。

arcsinh(x, /)

反双曲正弦,逐元素计算。

arctan(x, /)

三角反正切,逐元素计算。

arctan2(x1, x2, /)

x1/x2 的逐元素反正切,正确选择象限。

arctanh(x, /)

反双曲正切,逐元素计算。

argmax(a[, axis, out, keepdims])

返回沿轴的最大值的索引。

argmin(a[, axis, out, keepdims])

返回沿轴的最小值的索引。

argpartition(a, kth[, axis])

返回部分排序数组的索引。

argsort(a[, axis, kind, order, stable, ...])

返回排序数组的索引。

argwhere(a, *[, size, fill_value])

查找非零数组元素的索引

around(a[, decimals, out])

jax.numpy.round() 的别名

array(object[, dtype, copy, order, ndmin, ...])

将对象转换为 JAX 数组。

array_equal(a1, a2[, equal_nan])

检查两个数组是否逐元素相等。

array_equiv(a1, a2)

检查两个数组是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

返回数组的字符串表示形式。

array_split(ary, indices_or_sections[, axis])

将数组拆分为子数组。

array_str(a[, max_line_width, precision, ...])

返回数组中数据的字符串表示形式。

asarray(a[, dtype, order, copy, device])

将对象转换为 JAX 数组。

asin(x, /)

jax.numpy.arcsin() 的别名

asinh(x, /)

jax.numpy.arcsinh() 的别名

astype(x, dtype, /, *[, copy, device])

将数组转换为指定的 dtype。

atan(x, /)

jax.numpy.arctan() 的别名

atanh(x, /)

jax.numpy.arctanh() 的别名

atan2(x1, x2, /)

jax.numpy.arctan2() 的别名

atleast_1d()

将输入转换为至少具有 1 个维度的数组。

atleast_2d()

将输入转换为至少具有 2 个维度的数组。

atleast_3d()

将输入转换为至少具有 3 个维度的数组。

average()

沿指定的轴计算加权平均值。

bartlett(M)

返回 Bartlett 窗。

bincount(x[, weights, minlength, length])

计算整数数组中每个值的出现次数。

bitwise_and

逐元素计算按位 AND 操作。

bitwise_count(x, /)

计算x中每个元素的绝对值的二进制表示中1的个数。

bitwise_invert(x, /)

计算按位取反,或按位非,逐元素操作。

bitwise_left_shift(x, y, /)

将整数的位向左移动。

bitwise_not(x, /)

计算按位取反,或按位非,逐元素操作。

bitwise_or

逐元素计算按位或运算。

bitwise_right_shift(x1, x2, /)

jax.numpy.right_shift()的别名。

bitwise_xor

逐元素计算按位异或运算。

blackman(M)

返回Blackman窗。

block(arrays)

从嵌套的块列表中组装一个nd数组。

bool_

bool的别名

broadcast_arrays(*args)

将数组广播到一个共同的形状。

broadcast_shapes()

将输入形状广播到一个共同的输出形状。

broadcast_to(array, shape)

将数组广播到指定的形状。

c_

沿着最后一个轴连接切片、标量和类数组对象。

can_cast(from_, to[, casting])

如果根据转换规则可以发生数据类型之间的转换,则返回True。

cbrt(x, /)

返回数组的立方根,逐元素操作。

cdouble

complex128的别名

ceil(x, /)

向上舍入输入到最接近的整数。

character()

所有字符字符串标量类型的抽象基类。

choose(a, choices[, out, mode])

根据索引数组和要从中选择的数组列表构造一个数组。

clip([arr, min, max, a, a_min, a_max])

将数组值剪辑到指定的范围。

column_stack(tup)

按列堆叠数组。

complex_

complex128的别名

complex128(x)

complex64(x)

complexfloating()

所有由浮点数组成的复数标量类型的抽象基类。

ComplexWarning

将复数dtype转换为实数dtype时引发的警告。

compress(condition, a[, axis, size, ...])

使用布尔条件沿给定轴压缩数组。

concat(arrays, /, *[, axis])

沿着现有轴连接数组。

concatenate(arrays[, axis, dtype])

沿着现有轴连接数组。

conj(x, /)

返回复共轭,逐元素操作。

conjugate(x, /)

返回复共轭,逐元素操作。

convolve(a, v[, mode, precision, ...])

两个一维数组的卷积。

copy(a[, order])

返回数组的副本。

copysign(x1, x2, /)

x2中每个元素的符号复制到x1中相应的元素。

corrcoef(x[, y, rowvar])

返回皮尔逊积矩相关系数。

correlate(a, v[, mode, precision, ...])

两个一维数组的相关性。

cos(x, /)

计算输入每个元素的三角余弦。

cosh(x, /)

双曲余弦,逐元素操作。

count_nonzero(a[, axis, keepdims])

返回沿给定轴非零元素的数量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

根据数据和权重估计协方差矩阵。

cross(a, b[, axisa, axisb, axisc, axis])

返回两个(向量数组)的叉积。

csingle

complex64的别名

cumprod(a[, axis, dtype, out])

沿轴计算元素的累积积。

cumsum(a[, axis, dtype, out])

沿轴计算元素的累积和。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿数组的轴计算累积和。

deg2rad(x, /)

将角度从度转换为弧度。

degrees(x, /)

将角度从弧度转换为度。

delete(arr, obj[, axis, assume_unique_indices])

从数组中删除条目或条目。

diag(v[, k])

返回指定的对角线或构造对角线数组。

diag_indices(n[, ndim])

返回用于访问多维数组主对角线的索引。

diag_indices_from(arr)

返回用于访问给定数组主对角线的索引。

diagflat(v[, k])

返回一个二维数组,其中扁平化的输入数组位于对角线上。

diagonal(a[, offset, axis1, axis2])

返回数组的指定对角线。

diff(a[, n, axis, prepend, append])

计算沿给定轴的数组元素之间的n阶差分。

digitize(x, bins[, right])

返回输入数组中每个值所属的bin的索引。

divide(x1, x2, /)

jax.numpy.true_divide()的别名。

divmod(x1, x2, /)

逐元素计算x1除以x2的整数商和余数

dot(a, b, *[, precision, preferred_element_type])

计算两个数组的点积。

double

float64 的别名

dsplit(ary, indices_or_sections)

沿深度方向将数组分割成子数组。

dstack(tup[, dtype])

沿深度方向堆叠数组。

dtype(dtype[, align, copy])

创建一个数据类型对象。

ediff1d(ary[, to_end, to_begin])

计算扁平化数组元素的差值。

einsum()

爱因斯坦求和

einsum_path()

在不评估 einsum 的情况下评估最佳收缩路径。

empty(shape[, dtype, device])

创建一个空的数组。

empty_like(prototype[, dtype, shape, device])

创建一个与数组具有相同形状和数据类型的空数组。

equal(x, y, /)

逐元素返回 (x1 == x2)。

exp(x, /)

计算输入的逐元素指数。

exp2(x, /)

计算输入的逐元素以 2 为底的指数。

expand_dims(a, axis)

在数组中插入长度为 1 的维度

expm1(x, /)

计算输入每个元素的 exp(x)-1

extract(condition, arr, *[, size, fill_value])

返回满足条件的数组元素。

eye(N[, M, k, dtype, device])

创建一个方阵或矩形单位矩阵

fabs(x, /)

计算实值输入的逐元素绝对值。

fill_diagonal(a, val[, wrap, inplace])

返回一个数组的副本,其中对角线被覆盖。

finfo(dtype)

浮点类型的机器限制。

fix(x[, out])

将输入四舍五入到最接近零的整数。

flatnonzero(a, *[, size, fill_value])

返回扁平化数组中非零元素的索引

flexible()

所有没有预定义长度的标量类型的抽象基类。

flip(m[, axis])

沿给定轴反转数组元素的顺序。

fliplr(m)

沿轴 1 反转数组元素的顺序。

flipud(m)

沿轴 0 反转数组元素的顺序。

float_

float64 的别名

float_power(x, y, /)

计算 y 的逐元素以 x 为底的指数。

float16(x)

float32(x)

float64(x)

floating()

所有浮点标量类型的抽象基类。

floor(x, /)

将输入向下舍入到最接近的整数。

floor_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的向下取整除法

fmax(x1, x2)

返回输入数组的逐元素最大值。

fmin(x1, x2)

返回输入数组的逐元素最小值。

fmod(x1, x2, /)

返回除法的逐元素余数。

frexp(x, /)

将 x 的元素分解成尾数和以 2 为底的指数。

frombuffer(buffer[, dtype, count, offset])

将缓冲区转换为一维 JAX 数组。

fromfile(*args, **kwargs)

jnp.fromfile 的未实现 JAX 包装器。

fromfunction(function, shape, *[, dtype])

通过在每个坐标上执行函数来构造数组。

fromiter(*args, **kwargs)

jnp.fromiter 的未实现 JAX 包装器。

frompyfunc(func, /, nin, nout, *[, identity])

从任意兼容 JAX 的标量函数创建 JAX ufunc。

fromstring(string[, dtype, count])

将文本字符串转换为一维 JAX 数组。

from_dlpack(x, /, *[, device, copy])

通过 DLPack 构造 JAX 数组。

full(shape, fill_value[, dtype, device])

创建一个充满指定值的数组。

full_like(a, fill_value[, dtype, shape, device])

创建一个充满指定值的数组,其形状和数据类型与数组相同。

gcd(x1, x2)

计算两个数组的最大公约数。

generic()

numpy 标量类型的基类。

geomspace(start, stop[, num, endpoint, ...])

返回对数刻度上均匀间隔的数字(几何级数)。

get_printoptions()

返回当前打印选项。

gradient(f, *varargs[, axis, edge_order])

返回 N 维数组的梯度。

greater(x, y, /)

返回 x > y 的逐元素真值。

greater_equal(x, y, /)

返回 x >= y 的逐元素真值。

hamming(M)

返回汉明窗。

hanning(M)

返回汉宁窗。

heaviside(x1, x2, /)

计算海维赛德阶跃函数。

histogram(a[, bins, range, weights, density])

计算数据集的直方图。

histogram_bin_edges(a[, bins, range, weights])

仅计算 histogram 使用的 bin 边缘的函数

histogram2d(x, y[, bins, range, weights, ...])

计算两个数据样本的二维直方图。

histogramdd(sample[, bins, range, weights, ...])

计算某些数据的多分量直方图。

hsplit(ary, indices_or_sections)

水平分割数组为子数组。

hstack(tup[, dtype])

水平堆叠数组。

hypot(x1, x2, /)

给定直角三角形的“直角边”,返回其斜边。

i0

第一类修正贝塞尔函数,阶数为0。

identity(n[, dtype])

创建一个正方形单位矩阵

iinfo(int_type)

imag(val, /)

返回复数参数的虚部。

index_exp

一种更友好的构建数组索引元组的方式。

indices()

返回表示网格索引的数组。

inexact()

所有数值标量类型的抽象基类,其值范围(可能)以不精确的方式表示,例如浮点数。

inner(a, b, *[, precision, ...])

计算两个数组的内积。

insert(arr, obj, values[, axis])

在给定轴上的给定索引之前插入值。

int_

int64 的别名

int16(x)

int32(x)

int64(x)

int8(x)

integer()

所有整数标量类型的抽象基类。

interp(x, xp, fp[, left, right, period])

单维线性插值,适用于单调递增的样本点。

intersect1d(ar1, ar2[, assume_unique, ...])

计算两个一维数组的交集。

invert(x, /)

计算按位取反,或按位非,逐元素操作。

isclose(a, b[, rtol, atol, equal_nan])

检查两个数组的元素在容差范围内是否近似相等。

iscomplex(x)

返回布尔数组,显示输入在何处为复数。

iscomplexobj(x)

检查输入是否为复数或包含复数元素的数组。

isdtype(dtype, kind)

返回一个布尔值,指示提供的 dtype 是否为指定类型。

isfinite(x, /)

逐元素测试有限性(不是无穷大,也不是非数字)。

isin(element, test_elements[, ...])

确定element中的元素是否出现在test_elements中。

isinf(x, /)

逐元素测试正无穷或负无穷。

isnan(x, /)

逐元素测试 NaN 并将结果作为布尔数组返回。

isneginf(x, /[, out])

逐元素测试负无穷,将结果作为布尔数组返回。

isposinf(x, /[, out])

逐元素测试负无穷,将结果作为布尔数组返回。

isreal(x)

返回布尔数组,显示输入在何处为实数。

isrealobj(x)

检查输入是否不是复数或包含复数元素的数组。

isscalar(element)

如果element的类型是标量类型,则返回 True。

issubdtype(arg1, arg2)

如果第一个参数是类型层次结构中较低/相等的类型代码,则返回 True。

iterable(y)

检查对象是否可以迭代。

ix_(*args)

从 N 个一维序列返回一个多维网格(开放网格)。

kaiser(M, beta)

返回凯泽窗。

kron(a, b)

计算两个输入数组的克罗内克积。

lcm(x1, x2)

计算两个数组的最小公倍数。

ldexp(x1, x2, /)

逐元素返回 x1 * 2**x2。

left_shift(x, y, /)

将整数的位向左移动。

less(x, y, /)

返回x < y的逐元素真值。

less_equal(x, y, /)

返回x <= y的逐元素真值。

lexsort(keys[, axis])

使用一系列键执行间接稳定排序。

linspace()

在区间内返回等间距的数字。

load(*args, **kwargs)

.npy.npz或pickle文件中加载数组或pickle对象。

log(x, /)

计算输入的逐元素自然对数。

log10(x, /)

逐元素计算 x 的以 10 为底的对数

log1p(x, /)

计算一加输入的逐元素对数,log(x+1)

log2(x, /)

逐元素计算x的以 2 为底的对数。

logaddexp(x1, x2, /)

计算log(exp(x1) + exp(x2)),避免溢出。

logaddexp2(x1, x2, /)

以 2 为底计算输入指数和的对数,避免溢出。

logical_and

逐元素计算逻辑 AND 运算。

logical_not(x, /)

逐元素计算 NOT x 的真值。

logical_or

逐元素计算逻辑 OR 运算。

逻辑异或

逐元素计算逻辑异或运算。

logspace(start, stop[, num, endpoint, base, ...])

返回对数刻度上均匀分布的数字。

mask_indices(*args, **kwargs)

给定一个掩码函数,返回访问 (n, n) 数组的索引。

matmul(a, b, *[, precision, ...])

执行矩阵乘法。

matrix_transpose(x, /)

转置数组的最后两个维度。

max(a[, axis, out, keepdims, initial, where])

沿给定轴返回数组元素的最大值。

maximum(x, y, /)

返回输入数组的逐元素最大值。

mean(a[, axis, dtype, out, keepdims, where])

沿给定轴返回数组元素的平均值。

median(a[, axis, out, overwrite_input, keepdims])

沿给定轴返回数组元素的中位数。

meshgrid(*xi[, copy, sparse, indexing])

从坐标向量返回坐标矩阵的元组。

mgrid

返回密集的多维“网格”。

min(a[, axis, out, keepdims, initial, where])

沿给定轴返回数组元素的最小值。

minimum(x, y, /)

返回输入数组的逐元素最小值。

mod(x1, x2, /)

返回除法的逐元素余数。

modf(x, /[, out])

逐元素返回数组的小数部分和整数部分。

moveaxis(a, source, destination)

将数组轴移动到新位置

乘法

逐元素将两个数组相乘。

nan_to_num(x[, copy, nan, posinf, neginf])

将 NaN 替换为零,将无穷大替换为大的有限数字(默认

nanargmax(a[, axis, out, keepdims])

返回指定轴上最大值的索引,忽略

nanargmin(a[, axis, out, keepdims])

返回指定轴上最小值的索引,忽略

nancumprod(a[, axis, dtype, out])

沿轴计算元素的累积积,忽略 NaN 值。

nancumsum(a[, axis, dtype, out])

沿轴计算元素的累积和,忽略 NaN 值。

nanmax(a[, axis, out, keepdims, initial, where])

沿给定轴返回数组元素的最大值,忽略 NaN。

nanmean(a[, axis, dtype, out, keepdims, where])

沿给定轴返回数组元素的平均值,忽略 NaN。

nanmedian(a[, axis, out, overwrite_input, ...])

沿给定轴返回数组元素的中位数,忽略 NaN。

nanmin(a[, axis, out, keepdims, initial, where])

沿给定轴返回数组元素的最小值,忽略 NaN。

nanpercentile(a, q[, axis, out, ...])

计算数据沿指定轴的百分位数,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

沿给定轴返回数组元素的乘积,忽略 NaN。

nanquantile(a, q[, axis, out, ...])

计算数据沿指定轴的分位数,忽略 NaN。

nanstd(a[, axis, dtype, out, ddof, ...])

沿给定轴计算标准差,忽略 NaN。

nansum(a[, axis, dtype, out, keepdims, ...])

沿给定轴返回数组元素的和,忽略 NaN。

nanvar(a[, axis, dtype, out, ddof, ...])

沿给定轴计算数组元素的方差,忽略 NaN。

ndarray

Array 的别名

ndim(a)

返回数组的维度数。

negative(x, /)

返回输入的逐元素负值。

nextafter(x, y, /)

返回x朝向y的下一个浮点值。

nonzero(a, *[, size, fill_value])

返回数组中非零元素的索引。

not_equal(x, y, /)

返回 (x1 != x2) 的逐元素结果。

数字()

所有数值标量类型的抽象基类。

object_

任何 Python 对象。

ogrid

返回开放的多维“网格”。

ones(shape[, dtype, device])

创建一个全为一的数组。

ones_like(a[, dtype, shape, device])

创建一个与数组具有相同形状和数据类型的全为一的数组。

outer(a, b[, out])

计算两个数组的外积。

packbits(a[, axis, bitorder])

将二值数组的元素打包到 uint8 数组的位中。

pad(array, pad_width[, mode])

填充数组。

partition(a, kth[, axis])

返回数组的部分排序副本。

percentile(a, q[, axis, out, ...])

计算数据沿指定轴的百分位数。

permute_dims(a, /, axes)

置换数组的轴/维度。

piecewise(x, condlist, funclist, *args, **kw)

评估跨域定义的分段函数。

place(arr, mask, vals, *[, inplace])

基于掩码更新数组元素。

poly(seq_of_zeros)

返回给定根序列的多项式的系数。

polyadd(a1, a2)

返回两个多项式的和。

polyder(p[, m])

返回指定阶多项式导数的系数。

polydiv(u, v, *[, trim_leading_zeros])

返回多项式除法的商和余数。

polyfit(x, y, deg[, rcond, full, w, cov])

对数据进行最小二乘多项式拟合。

polyint(p[, m, k])

返回指定阶多项式积分的系数。

polymul(a1, a2, *[, trim_leading_zeros])

返回两个多项式的乘积。

polysub(a1, a2)

返回两个多项式的差。

polyval(p, x, *[, unroll])

在特定值处计算多项式的值。

positive(x, /)

返回输入的逐元素正值。

pow(x1, x2, /)

jax.numpy.power() 的别名

power(x1, x2, /)

计算 x1x2 次幂的逐元素运算。

printoptions(*args, **kwargs)

用于设置打印选项的上下文管理器。

prod(a[, axis, dtype, out, keepdims, ...])

返回沿给定轴的数组元素的乘积。

promote_types(a, b)

返回二元运算应将其参数转换为的类型。

ptp(a[, axis, out, keepdims])

返回沿给定轴的峰峰值范围。

put(a, ind, v[, mode, inplace])

将元素放入给定索引处的数组中。

quantile(a, q[, axis, out, overwrite_input, ...])

沿指定轴计算数据的分位数。

r_

沿第一个轴连接切片、标量和类似数组的对象。

rad2deg(x, /)

将角度从弧度转换为度。

radians(x, /)

将角度从度转换为弧度。

ravel(a[, order])

将数组展平为一维形状。

ravel_multi_index(multi_index, dims[, mode, ...])

将多维索引转换为扁平索引。

real(val, /)

返回复数参数的实部。

reciprocal(x, /)

返回参数的倒数,逐元素计算。

remainder(x1, x2, /)

返回除法的逐元素余数。

repeat(a, repeats[, axis, total_repeat_length])

从重复的元素构造一个数组。

reshape(a[, shape, order, newshape, copy])

返回数组的重新整形副本。

resize(a, new_shape)

返回具有指定形状的新数组。

result_type(*args)

返回应用 NumPy

right_shift(x1, x2, /)

x1 的位向右移动 x2 指定的数量。

rint(x, /)

将 x 的元素四舍五入到最接近的整数

roll(a, shift[, axis])

沿指定轴滚动数组的元素。

rollaxis(a, axis[, start])

将指定的轴滚动到给定位置。

roots(p, *[, strip_zeros])

给定系数 p 返回多项式的根。

rot90(m[, k, axes])

将数组在由轴指定的平面中逆时针旋转 90 度。

round(a[, decimals, out])

将输入均匀四舍五入到给定的十进制位数。

round_(a[, decimals, out])

将输入均匀四舍五入到给定的十进制位数。

s_

一种更友好的构建数组索引元组的方式。

save(file, arr[, allow_pickle, fix_imports])

将数组保存到 NumPy .npy 格式的二进制文件中。

savez(file, *args, **kwds)

将多个数组保存到单个未压缩的 .npz 格式文件中。

searchsorted(a, v[, side, sorter, method])

在已排序的数组中执行二分查找。

select(condlist, choicelist[, default])

根据一系列条件选择值。

set_printoptions([precision, threshold, ...])

设置打印选项。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

计算两个一维数组的集合差。

setxor1d(ar1, ar2[, assume_unique, size, ...])

计算两个数组中元素的集合异或。

shape(a)

返回数组的形状。

sign(x, /)

返回输入的符号的逐元素指示。

signbit(x, /)

在符号位设置(小于零)的位置返回逐元素的 True。

signedinteger()

所有带符号整数标量类型的抽象基类。

sin(x, /)

计算输入的每个元素的三角正弦。

sinc(x, /)

返回归一化的 sinc 函数。

single

float32 的别名

sinh(x, /)

双曲正弦,逐元素计算。

size(a[, axis])

返回沿给定轴的元素数量。

sort(a[, axis, kind, order, stable, descending])

返回数组的已排序副本。

sort_complex(a)

根据实部,然后根据虚部对复杂数组进行排序。

split(ary, indices_or_sections[, axis])

将数组拆分为子数组。

sqrt(x, /)

逐元素返回数组的非负平方根。

square(x, /)

返回输入的逐元素平方。

squeeze(a[, axis])

从数组中删除一个或多个长度为 1 的轴。

stack(arrays[, axis, out, dtype])

沿着新轴连接数组。

std(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

subtract(x, y, /)

逐元素减去参数。

sum(a[, axis, dtype, out, keepdims, ...])

沿给定轴计算数组元素的和。

swapaxes(a, axis1, axis2)

交换数组的两个轴。

take(a, indices[, axis, out, mode, ...])

从数组中获取元素。

take_along_axis(arr, indices, axis[, mode, ...])

从数组中获取元素。

tan(x, /)

计算输入每个元素的三角正切。

tanh(x, /)

逐元素计算双曲正切。

tensordot(a, b[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tile(A, reps)

通过重复 A reps 指定的次数来构造一个数组。

trace(a[, offset, axis1, axis2, dtype, out])

返回数组对角线元素的和。

trapezoid(y[, x, dx, axis])

使用复合梯形规则沿给定轴积分。

transpose(a[, axes])

返回 N 维数组的转置版本。

tri(N[, M, k, dtype])

返回一个数组,在对角线及其下方为 1,其他位置为 0。

tril(m[, k])

返回数组的下三角部分。

tril_indices(n[, k, m])

返回大小为 (n, m) 的数组下三角部分的索引。

tril_indices_from(arr[, k])

返回给定数组下三角部分的索引。

trim_zeros(filt[, trim])

修剪输入数组开头和/或结尾的零。

triu(m[, k])

返回数组的上三角部分。

triu_indices(n[, k, m])

返回大小为 (n, m) 的数组上三角部分的索引。

triu_indices_from(arr[, k])

返回给定数组上三角部分的索引。

true_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的结果。

trunc(x)

将输入四舍五入到最接近零的整数。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

通用函数,对数组逐元素进行运算。

uint

uint64 的别名

uint16(x)

uint32(x)

uint64(x)

uint8(x)

union1d(ar1, ar2, *[, size, fill_value])

计算两个一维数组的并集。

unique(ar[, return_index, return_inverse, ...])

返回数组中唯一的值。

unique_all(x, /, *[, size, fill_value])

返回 x 中唯一的值,以及索引、反向索引和计数。

unique_counts(x, /, *[, size, fill_value])

返回 x 中唯一的值以及计数。

unique_inverse(x, /, *[, size, fill_value])

返回 x 中唯一的值,以及索引、反向索引和计数。

unique_values(x, /, *[, size, fill_value])

返回 x 中唯一的值,以及索引、反向索引和计数。

unpackbits(a[, axis, count, bitorder])

将 uint8 数组的元素解包到二进制值的输出数组中。

unravel_index(indices, shape)

将扁平索引转换为多维索引。

unstack(x, /, *[, axis])

沿轴解开数组。

unsignedinteger()

所有无符号整数标量类型的抽象基类。

unwrap(p[, discont, axis, period])

通过取相对于周期的较大增量的补码来解开。

vander(x[, N, increasing])

生成范德蒙矩阵。

var(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的方差。

vdot(a, b, *[, precision, ...])

执行两个一维向量的共轭乘法。

vecdot(x1, x2, /, *[, axis, precision, ...])

执行两个批处理向量的共轭乘法。

vectorize(pyfunc, *[, excluded, signature])

定义具有广播功能的向量化函数。

vsplit(ary, indices_or_sections)

垂直拆分数组为子数组。

vstack(tup[, dtype])

垂直堆叠数组。

where()

根据条件从两个数组中选择元素。

zeros(shape[, dtype, device])

创建一个全零数组。

zeros_like(a[, dtype, shape, device])

创建一个与给定数组形状和数据类型相同的全零数组。

jax.numpy.fft#

fft(a[, n, axis, norm])

沿给定轴计算一维离散傅里叶变换。

fft2(a[, s, axes, norm])

沿给定轴计算二维离散傅里叶变换。

fftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

fftn(a[, s, axes, norm])

沿给定轴计算多维离散傅里叶变换。

fftshift(x[, axes])

将零频率分量移到频谱的中心。

hfft(a[, n, axis, norm])

计算一个数组的一维 FFT,该数组的频谱具有厄米对称性。

ifft(a[, n, axis, norm])

计算一维逆离散傅里叶变换。

ifft2(a[, s, axes, norm])

计算二维逆离散傅里叶变换。

ifftn(a[, s, axes, norm])

计算多维逆离散傅里叶变换。

ifftshift(x[, axes])

fftshift的逆运算。

ihfft(a[, n, axis, norm])

计算一个数组的一维逆 FFT,该数组的频谱具有厄米对称性。

irfft(a[, n, axis, norm])

计算实值一维逆离散傅里叶变换。

irfft2(a[, s, axes, norm])

计算实值二维逆离散傅里叶变换。

irfftn(a[, s, axes, norm])

计算实值多维逆离散傅里叶变换。

rfft(a[, n, axis, norm])

计算实值数组的一维离散傅里叶变换。

rfft2(a[, s, axes, norm])

计算实值数组的二维离散傅里叶变换。

rfftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

rfftn(a[, s, axes, norm])

计算实值数组的多维离散傅里叶变换。

jax.numpy.linalg#

cholesky(a, *[, upper])

计算矩阵的Cholesky分解。

cond(x[, p])

计算矩阵的条件数。

cross(x1, x2, /, *[, axis])

计算两个3D向量的叉积。

det(a)

计算数组的行列式。

diagonal(x, /, *[, offset])

提取矩阵或矩阵堆栈的对角线。

eig(a)

计算方阵的特征值和特征向量。

eigh(a[, UPLO, symmetrize_input])

计算厄米矩阵的特征值和特征向量。

eigvals(a)

计算一般矩阵的特征值。

eigvalsh(a[, UPLO])

计算厄米矩阵的特征值。

inv(a)

返回方阵的逆矩阵。

lstsq(a, b[, rcond, numpy_resid])

返回线性方程组的最小二乘解。

matmul(x1, x2, /, *[, precision, ...])

执行矩阵乘法。

matrix_norm(x, /, *[, keepdims, ord])

计算矩阵或矩阵堆栈的范数。

matrix_power(a, n)

将方阵提升到整数次幂。

matrix_rank(M[, rtol, tol])

计算矩阵的秩。

matrix_transpose(x, /)

转置矩阵或矩阵堆栈。

multi_dot(arrays, *[, precision])

有效地计算一系列数组之间的矩阵乘积。

norm(x[, ord, axis, keepdims])

计算矩阵或向量的范数。

outer(x1, x2, /)

计算两个一维数组的外积。

pinv(a[, rtol, hermitian, rcond])

计算矩阵的(Moore-Penrose)伪逆。

qr()

计算数组的QR分解。

slogdet(a, *[, method])

计算数组的行列式的符号和(自然)对数。

solve(a, b)

解线性方程组。

svd()

计算奇异值分解。

svdvals(x, /)

计算矩阵的奇异值。

tensordot(x1, x2, /, *[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tensorinv(a[, ind])

计算数组的张量逆。

tensorsolve(a, b[, axes])

求解张量方程 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

计算矩阵的迹。

vector_norm(x, /, *[, axis, keepdims, ord])

计算向量或向量批次的向量范数。

vecdot(x1, x2, /, *[, axis, precision, ...])

计算两个数组的(批处理)向量共轭点积。

JAX 数组#

JAX 的 Array(及其别名 jax.numpy.ndarray)是 JAX 中的核心数组对象:您可以将其视为 JAX 中等效于 numpy.ndarray 的对象。与 numpy.ndarray 一样,大多数用户无需手动实例化 Array 对象,而是通过 jax.numpy 中的函数(如 array()arange()linspace() 等,如上所述)来创建它们。

复制和序列化#

JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝协作。

使用内置的 copy 模块,当 copy.copy()copy.deepcopy() 遇到 Array 时,它等效于调用 copy() 方法,这将在与原始数组相同的设备上创建缓冲区的副本。这将在跟踪/JIT 编译的代码中正常工作,尽管在这种情况下,编译器可能会省略复制操作。

当内置的 pickle 模块遇到 Array 时,它将通过类似于腌制 numpy.ndarray 对象的紧凑位表示形式进行序列化。取消腌制后,结果将是一个新的 Array 对象,**位于默认设备上**。这是因为通常,腌制和取消腌制可能发生在不同的运行时环境中,并且没有普遍的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在跟踪/JIT 编译的代码中使用 pickle,则会导致 ConcretizationTypeError

Python 数组 API 标准#

注意

在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 才能为 JAX 数组启用数组 API。在 JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。

从 JAX v0.4.32 开始,jax.Arrayjax.numpyPython 数组 API 标准 兼容。您可以通过 jax.Array.__array_namespace__() 访问数组 API 命名空间。

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在少数几个地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块得到解决。

有关更多信息,请参阅 Python 数组 API 标准 文档。