jax.numpy 模块#

使用 jax.lax 中的原语实现 NumPy API。

虽然 JAX 尝试尽可能紧密地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。

  • 值得注意的是,由于 JAX 数组是不可变的,NumPy 中会原地修改数组的 API 无法在 JAX 中实现。然而,JAX 通常能够提供一种纯函数式的替代 API。例如,JAX 提供了一个纯粹的索引更新函数 x.at[i].set(y),而不是原地数组更新 (x[i] = y)(请参阅 ndarray.at)。

  • 类似地,一些 NumPy 函数在可能的情况下会返回数组的视图(例如 transpose()reshape())。JAX 的这些函数版本将返回副本,但是当使用 jax.jit() 编译操作序列时,XLA 通常会优化掉这些副本。

  • NumPy 在将值提升为 float64 类型方面非常积极。JAX 在类型提升方面有时没有那么积极(请参阅 类型提升语义)。

  • 一些 NumPy 例程具有数据相关的输出形状(例如 unique()nonzero())。由于 XLA 编译器要求数组形状在编译时已知,因此此类操作与 JIT 不兼容。因此,JAX 为此类函数添加了一个可选的 size 参数,可以在静态指定此参数的情况下将它们与 JIT 一起使用。

几乎所有适用的 NumPy 函数都在 jax.numpy 命名空间中实现;它们在下面列出。

ndarray.at

用于索引更新功能的辅助属性。

abs(x, /)

jax.numpy.absolute() 的别名。

absolute(x, /)

计算逐元素的绝对值。

acos(x, /)

jax.numpy.arccos() 的别名

acosh(x, /)

jax.numpy.arccosh() 的别名

add

逐元素地相加两个数组。

all(a[, axis, out, keepdims, where])

测试给定轴上的所有数组元素是否都计算为 True。

allclose(a, b[, rtol, atol, equal_nan])

检查两个数组在容差范围内是否逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

jax.numpy.max() 的别名。

amin(a[, axis, out, keepdims, initial, where])

jax.numpy.min() 的别名。

angle(z[, deg])

返回复数值数字或数组的角度。

any(a[, axis, out, keepdims, where])

测试给定轴上的任何数组元素是否计算为 True。

append(arr, values[, axis])

返回一个新数组,其中值附加到原始数组的末尾。

apply_along_axis(func1d, axis, arr, *args, ...)

沿轴将函数应用于一维数组切片。

apply_over_axes(func, a, axes)

在指定的轴上重复应用函数。

arange(start[, stop, step, dtype, device])

创建均匀间隔的值的数组。

arccos(x, /)

计算输入的三角余弦的逐元素逆。

arccosh(x, /)

计算输入的双曲余弦的逐元素逆。

arcsin(x, /)

计算输入的三角正弦的逐元素逆。

arcsinh(x, /)

计算输入的双曲正弦的逐元素逆。

arctan(x, /)

计算输入的三角正切的逐元素逆。

arctan2(x1, x2, /)

计算 x1/x2 的反正切,选择正确的象限。

arctanh(x, /)

计算输入的双曲正切的逐元素逆。

argmax(a[, axis, out, keepdims])

返回数组最大值的索引。

argmin(a[, axis, out, keepdims])

返回数组最小值的索引。

argpartition(a, kth[, axis])

返回部分排序数组的索引。

argsort(a[, axis, kind, order, stable, ...])

返回对数组排序的索引。

argwhere(a, *[, size, fill_value])

查找非零数组元素的索引

around(a[, decimals, out])

jax.numpy.round() 的别名

array(object[, dtype, copy, order, ndmin, ...])

将对象转换为 JAX 数组。

array_equal(a1, a2[, equal_nan])

检查两个数组是否逐元素相等。

array_equiv(a1, a2)

检查两个数组是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

返回数组的字符串表示形式。

array_split(ary, indices_or_sections[, axis])

将数组拆分为子数组。

array_str(a[, max_line_width, precision, ...])

返回数组中数据的字符串表示形式。

asarray(a[, dtype, order, copy, device])

将对象转换为 JAX 数组。

asin(x, /)

jax.numpy.arcsin() 的别名

asinh(x, /)

jax.numpy.arcsinh() 的别名

astype(x, dtype, /, *[, copy, device])

将数组转换为指定的 dtype。

atan(x, /)

jax.numpy.arctan() 的别名

atanh(x, /)

jax.numpy.arctanh() 的别名

atan2(x1, x2, /)

jax.numpy.arctan2() 的别名

atleast_1d(*arys)

将输入转换为至少一维的数组。

atleast_2d(*arys)

将输入转换为至少二维的数组。

atleast_3d(*arys)

将输入转换为至少三维的数组。

average(a[, axis, weights, returned, keepdims])

计算加权平均值。

bartlett(M)

返回大小为 M 的 Bartlett 窗口。

bincount(x[, weights, minlength, length])

计算整数数组中每个值的出现次数。

bitwise_and

逐元素计算按位与运算。

bitwise_count(x, /)

计算 x 的每个元素的绝对值的二进制表示形式中 1 的位数。

bitwise_invert(x, /)

jax.numpy.invert() 的别名。

bitwise_left_shift(x, y, /)

jax.numpy.left_shift() 的别名。

bitwise_not(x, /)

jax.numpy.invert() 的别名。

bitwise_or

逐元素计算按位或运算。

bitwise_right_shift(x1, x2, /)

jax.numpy.right_shift() 的别名。

bitwise_xor

逐元素计算按位异或运算。

blackman(M)

返回大小为 M 的 Blackman 窗口。

block(arrays)

从块列表创建数组。

bool_

bool 的别名

broadcast_arrays(*args)

将数组广播到共同的形状。

broadcast_shapes(*shapes)

将输入形状广播到共同的输出形状。

broadcast_to(array, shape)

将数组广播到指定的形状。

c_

沿最后一个轴连接切片、标量和类数组对象。

can_cast(from_, to[, casting])

如果根据转换规则可以进行数据类型之间的转换,则返回 True。

cbrt(x, /)

计算输入数组的逐元素立方根。

cdouble

complex128 的别名

ceil(x, /)

将输入向上舍入到最接近的整数。

character()

所有字符串标量类型的抽象基类。

choose(a, choices[, out, mode])

通过堆叠选择数组的切片来构造数组。

clip([arr, min, max, a, a_min, a_max])

将数组值剪切到指定范围。

column_stack(tup)

按列堆叠数组。

complex_

complex128 的别名

complex128(x)

类型为 complex128 的 JAX 标量构造函数。

complex64(x)

类型为 complex64 的 JAX 标量构造函数。

complexfloating()

所有由浮点数组成的复数标量类型的抽象基类。

ComplexWarning

将复数 dtype 转换为实数 dtype 时引发的警告。

compress(condition, a[, axis, size, ...])

使用布尔条件沿给定轴压缩数组。

concat(arrays, /, *[, axis])

沿现有轴连接数组。

concatenate(arrays[, axis, dtype])

沿现有轴连接数组。

conj(x, /)

jax.numpy.conjugate() 的别名

conjugate(x, /)

返回输入的逐元素复共轭。

convolve(a, v[, mode, precision, ...])

两个一维数组的卷积。

copy(a[, order])

返回数组的副本。

copysign(x1, x2, /)

x2 中每个元素的符号复制到 x1 中相应的元素。

corrcoef(x[, y, rowvar])

计算皮尔逊相关系数。

correlate(a, v[, mode, precision, ...])

两个一维数组的相关性。

cos(x, /)

计算输入的每个元素的三角余弦值。

cosh(x, /)

计算输入的逐元素双曲余弦。

count_nonzero(a[, axis, keepdims])

返回给定轴上非零元素的数量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

估计加权样本协方差。

cross(a, b[, axisa, axisb, axisc, axis])

计算两个数组的(批处理)叉积。

csingle

complex64 的别名

cumprod(a[, axis, dtype, out])

沿轴的元素的累积乘积。

cumsum(a[, axis, dtype, out])

沿轴的元素的累积和。

cumulative_prod(x, /, *[, axis, dtype, ...])

沿数组轴的累积乘积。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿数组轴计算累积和。

deg2rad(x, /)

将角度从度数转换为弧度。

degrees(x, /)

jax.numpy.rad2deg()的别名

delete(arr, obj[, axis, assume_unique_indices])

从数组中删除一个或多个条目。

diag(v[, k])

返回指定的对角线或构造一个对角线数组。

diag_indices(n[, ndim])

返回用于访问多维数组主对角线的索引。

diag_indices_from(arr)

返回用于访问给定数组主对角线的索引。

diagflat(v[, k])

返回一个 2-D 数组,其对角线上铺设了展平的输入数组。

diagonal(a[, offset, axis1, axis2])

返回数组的指定对角线。

diff(a[, n, axis, prepend, append])

计算沿给定轴的数组元素之间的 n 阶差分。

digitize(x, bins[, right, method])

将数组转换为 bin 索引。

divide(x1, x2, /)

jax.numpy.true_divide()的别名。

divmod(x1, x2, /)

按元素方式计算 x1 除以 x2 的整数商和余数

dot(a, b, *[, precision, preferred_element_type])

计算两个数组的点积。

double

float64的别名

dsplit(ary, indices_or_sections)

沿深度方向将数组拆分为子数组。

dstack(tup[, dtype])

沿深度方向堆叠数组。

dtype(dtype[, align, copy])

创建数据类型对象。

ediff1d(ary[, to_end, to_begin])

计算展平数组的元素的差值。

einsum(subscripts, /, *operands[, out, ...])

爱因斯坦求和

einsum_path(subscripts, /, *operands[, optimize])

计算最优收缩路径,而不计算 einsum。

empty(shape[, dtype, device])

创建一个空数组。

empty_like(prototype[, dtype, shape, device])

创建一个与数组具有相同形状和数据类型的空数组。

equal(x, y, /)

返回 x == y 的按元素真值。

exp(x, /)

计算输入的按元素指数。

exp2(x, /)

计算输入的按元素 2 为底的指数。

expand_dims(a, axis)

将长度为 1 的维度插入数组

expm1(x, /)

计算输入中每个元素的 exp(x)-1

extract(condition, arr, *[, size, fill_value])

返回满足条件的数组的元素。

eye(N[, M, k, dtype, device])

创建方形或矩形单位矩阵

fabs(x, /)

计算实值输入的按元素绝对值。

fill_diagonal(a, val[, wrap, inplace])

返回一个数组的副本,其中对角线被覆盖。

finfo(dtype)

浮点类型的机器限制。

fix(x[, out])

将输入四舍五入到最接近于零的整数。

flatnonzero(a, *[, size, fill_value])

返回展平数组中非零元素的索引

flexible()

所有没有预定义长度的标量类型的抽象基类。

flip(m[, axis])

沿给定轴反转数组元素的顺序。

fliplr(m)

沿轴 1 反转数组元素的顺序。

flipud(m)

沿轴 0 反转数组元素的顺序。

float_

float64的别名

float_power(x, y, /)

计算 y 的按元素底数 x 指数。

float16(x)

float16 类型的 JAX 标量构造函数。

float32(x)

float32 类型的 JAX 标量构造函数。

float64(x)

float64 类型的 JAX 标量构造函数。

floating()

所有浮点标量类型的抽象基类。

floor(x, /)

将输入四舍五入到最接近的向下整数。

floor_divide(x1, x2, /)

按元素方式计算 x1 除以 x2 的向下除法

fmax(x1, x2)

返回输入数组的按元素最大值。

fmin(x1, x2)

返回输入数组的按元素最小值。

fmod(x1, x2, /)

计算按元素浮点模运算。

frexp(x, /)

将浮点值拆分为尾数和 2 的指数。

frombuffer(buffer[, dtype, count, offset])

将缓冲区转换为 1-D JAX 数组。

fromfile(*args, **kwargs)

jnp.fromfile 的未实现 JAX 封装。

fromfunction(function, shape, *[, dtype])

从应用于索引的函数创建数组。

fromiter(*args, **kwargs)

jnp.fromiter 的未实现 JAX 封装。

frompyfunc(func, /, nin, nout, *[, identity])

从任意 JAX 兼容的标量函数创建 JAX ufunc。

fromstring(string[, dtype, count])

将文本字符串转换为 1 维 JAX 数组。

from_dlpack(x, /, *[, device, copy])

通过 DLPack 构建 JAX 数组。

full(shape, fill_value[, dtype, device])

创建一个填充指定值的数组。

full_like(a, fill_value[, dtype, shape, device])

创建一个具有与数组相同形状和 dtype 的、填充指定值的数组。

gcd(x1, x2)

计算两个数组的最大公约数。

通用()

numpy 标量类型的基类。

geomspace(start, stop[, num, endpoint, ...])

生成几何间隔的值。

get_printoptions()

返回当前打印选项。

gradient(f, *varargs[, axis, edge_order])

计算采样函数的数值梯度。

greater(x, y, /)

返回 x > y 的逐元素真值。

greater_equal(x, y, /)

返回 x >= y 的逐元素真值。

hamming(M)

返回大小为 M 的汉明窗。

hanning(M)

返回大小为 M 的汉宁窗。

heaviside(x1, x2, /)

计算赫维赛德阶跃函数。

histogram(a[, bins, range, weights, density])

计算一维直方图。

histogram_bin_edges(a[, bins, range, weights])

计算直方图的 bin 边缘。

histogram2d(x, y[, bins, range, weights, ...])

计算二维直方图。

histogramdd(sample[, bins, range, weights, ...])

计算 N 维直方图。

hsplit(ary, indices_or_sections)

将数组水平分割为子数组。

hstack(tup[, dtype])

水平堆叠数组。

hypot(x1, x2, /)

返回直角三角形给定边长的逐元素斜边。

i0(x)

计算第一类零阶修正贝塞尔函数。

identity(n[, dtype])

创建方形单位矩阵。

iinfo(int_type)

imag(val, /)

返回复数参数的逐元素虚部。

index_exp

一种构建数组索引元组的更友好的方法。

indices(dimensions[, dtype, sparse])

生成网格索引数组。

不精确()

所有数值标量类型的抽象基类,其范围内的值具有(可能)不精确的表示,例如浮点数。

inner(a, b, *[, precision, ...])

计算两个数组的内积。

insert(arr, obj, values[, axis])

在指定索引处将条目插入到数组中。

int_

int64 的别名

int16(x)

类型为 int16 的 JAX 标量构造函数。

int32(x)

类型为 int32 的 JAX 标量构造函数。

int64(x)

类型为 int64 的 JAX 标量构造函数。

int8(x)

类型为 int8 的 JAX 标量构造函数。

整数()

所有整数标量类型的抽象基类。

interp(x, xp, fp[, left, right, period])

一维线性插值。

intersect1d(ar1, ar2[, assume_unique, ...])

计算两个一维数组的集合交集。

invert(x, /)

计算输入的按位反转。

isclose(a, b[, rtol, atol, equal_nan])

检查两个数组的元素是否在容差范围内近似相等。

iscomplex(x)

返回显示输入是否为复数的布尔数组。

iscomplexobj(x)

检查输入是否为复数或包含复数元素的数组。

isdtype(dtype, kind)

返回一个布尔值,指示提供的 dtype 是否为指定的类型。

isfinite(x, /)

返回一个布尔数组,指示输入的每个元素是否为有限值。

isin(element, test_elements[, ...])

确定 element 中的元素是否出现在 test_elements 中。

isinf(x, /)

返回一个布尔数组,指示输入的每个元素是否为无限值。

isnan(x, /)

返回一个布尔数组,指示输入的每个元素是否为 NaN

isneginf(x, /[, out])

返回一个布尔数组,指示输入的每个元素是否为负无穷大。

isposinf(x, /[, out])

返回一个布尔数组,指示输入中的每个元素是否为正无穷大。

isreal(x)

返回一个布尔数组,显示输入是否为实数。

isrealobj(x)

检查输入是否不是复数或包含复数元素的数组。

isscalar(element)

如果输入是标量,则返回 True。

issubdtype(arg1, arg2)

如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。

iterable(y)

检查对象是否可以迭代。

ix_(*args)

从 N 个一维序列返回一个多维网格(开放网格)。

kaiser(M, beta)

返回大小为 M 的 Kaiser 窗口。

kron(a, b)

计算两个输入数组的克罗内克积。

lcm(x1, x2)

计算两个数组的最小公倍数。

ldexp(x1, x2, /)

计算 x1 * 2 ** x2

left_shift(x, y, /)

x 的位按元素方式左移 y 中指定的量。

less(x, y, /)

返回 x < y 的按元素真值。

less_equal(x, y, /)

返回 x <= y 的按元素真值。

lexsort(keys[, axis])

按字典顺序对键序列进行排序。

linspace(start, stop[, num, endpoint, ...])

返回间隔内的均匀间隔的数字。

load(file, *args, **kwargs)

从 npy 文件加载 JAX 数组。

log(x, /)

计算输入的按元素自然对数。

log10(x, /)

按元素计算 x 的以 10 为底的对数

log1p(x, /)

按元素计算 1 加输入的对数,log(x+1)

log2(x, /)

按元素计算 x 的以 2 为底的对数。

logaddexp

计算 log(exp(x1) + exp(x2)),避免溢出。

logaddexp2

以 2 为底的输入指数和的对数,避免溢出。

logical_and

按元素计算逻辑与运算。

logical_not(x, /)

按元素计算 NOT bool(x)。

logical_or

按元素计算逻辑或运算。

logical_xor

按元素计算逻辑异或运算。

logspace(start, stop[, num, endpoint, base, ...])

生成对数间隔的值。

mask_indices(n, mask_func[, k, size])

返回 (n, n) 数组的掩码的索引。

matmul(a, b, *[, precision, ...])

执行矩阵乘法。

matrix_transpose(x, /)

转置数组的最后两个维度。

matvec(x1, x2, /)

批量矩阵向量积。

max(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值。

maximum(x, y, /)

返回输入数组的按元素最大值。

mean(a[, axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的平均值。

median(a[, axis, out, overwrite_input, keepdims])

返回沿给定轴的数组元素的中位数。

meshgrid(*xi[, copy, sparse, indexing])

从 N 个一维向量构造 N 维网格数组。

mgrid

返回密集的 多维 “meshgrid”。

min(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值。

minimum(x, y, /)

返回输入数组的按元素最小值。

mod(x1, x2, /)

的别名jax.numpy.remainder()

modf(x, /[, out])

返回输入数组的按元素分数部分和整数部分。

moveaxis(a, source, destination)

将数组轴移动到新位置

multiply

按元素方式将两个数组相乘。

nan_to_num(x[, copy, nan, posinf, neginf])

替换数组中的 NaN 和无穷大条目。

nanargmax(a[, axis, out, keepdims])

返回数组最大值的索引,忽略 NaN。

nanargmin(a[, axis, out, keepdims])

返回数组最小值的索引,忽略 NaN。

nancumprod(a[, axis, dtype, out])

沿轴的元素的累积积,忽略 NaN 值。

nancumsum(a[, axis, dtype, out])

沿轴的元素的累积和,忽略 NaN 值。

nanmax(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值,忽略 NaN。

nanmean(a[, axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的平均值,忽略 NaN。

nanmedian(a[, axis, out, overwrite_input, ...])

返回沿给定轴的数组元素的中位数,忽略 NaN。

nanmin(a[, axis, out, keepdims, initial, where])

返回给定轴上数组元素的最小值,忽略 NaN 值。

nanpercentile(a, q[, axis, out, ...])

计算沿指定轴的数据的百分位数,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

返回给定轴上数组元素的乘积,忽略 NaN 值。

nanquantile(a, q[, axis, out, ...])

计算沿指定轴的数据的分位数,忽略 NaN 值。

nanstd(a[, axis, dtype, out, ddof, ...])

计算沿给定轴的标准差,忽略 NaN 值。

nansum(a[, axis, dtype, out, keepdims, ...])

返回给定轴上数组元素的总和,忽略 NaN 值。

nanvar(a[, axis, dtype, out, ddof, ...])

计算沿给定轴的数组元素的方差,忽略 NaN 值。

ndarray

Array 的别名

ndim(a)

返回数组的维度数。

negative

返回输入的逐元素负值。

nextafter(x, y, /)

返回元素方向上 x 之后朝向 y 的下一个浮点数值。

nonzero(a, *[, size, fill_value])

返回数组中非零元素的索引。

not_equal(x, y, /)

返回 x != y 的逐元素真值。

number()

所有数值标量类型的抽象基类。

object_

任何 Python 对象。

ogrid

返回开放的多维“meshgrid”。

ones(shape[, dtype, device])

创建一个充满 1 的数组。

ones_like(a[, dtype, shape, device])

创建一个与数组具有相同形状和 dtype 的全 1 数组。

outer(a, b[, out])

计算两个数组的外积。

packbits(a[, axis, bitorder])

将位数组打包成 uint8 数组。

pad(array, pad_width[, mode])

向数组添加填充。

partition(a, kth[, axis])

返回数组的部分排序副本。

percentile(a, q[, axis, out, ...])

计算沿指定轴的数据的百分位数。

permute_dims(a, /, axes)

置换数组的轴/维度。

piecewise(x, condlist, funclist, *args, **kw)

计算在域中分段定义的函数。

place(arr, mask, vals, *[, inplace])

基于掩码更新数组元素。

poly(seq_of_zeros)

返回给定根序列的多项式的系数。

polyadd(a1, a2)

返回两个多项式的和。

polyder(p[, m])

返回指定阶数的多项式导数的系数。

polydiv(u, v, *[, trim_leading_zeros])

返回多项式除法的商和余数。

polyfit(x, y, deg[, rcond, full, w, cov])

数据最小二乘多项式拟合。

polyint(p[, m, k])

返回指定阶数的多项式积分的系数。

polymul(a1, a2, *[, trim_leading_zeros])

返回两个多项式的乘积。

polysub(a1, a2)

返回两个多项式的差。

polyval(p, x, *[, unroll])

计算多项式在特定值上的值。

positive(x, /)

返回输入的逐元素正值。

pow(x1, x2, /)

jax.numpy.power() 的别名

power(x1, x2, /)

计算 x1 的以 x2 为指数的逐元素幂。

printoptions(*args, **kwargs)

用于设置打印选项的上下文管理器。

prod(a[, axis, dtype, out, keepdims, ...])

返回给定轴上数组元素的乘积。

promote_types(a, b)

返回二元运算应将其参数强制转换的类型。

ptp(a[, axis, out, keepdims])

返回沿给定轴的峰峰值范围。

put(a, ind, v[, mode, inplace])

将元素放入给定索引的数组中。

put_along_axis(arr, indices, values, axis[, ...])

通过匹配 1 维索引和数据切片,将值放入目标数组中。

quantile(a, q[, axis, out, overwrite_input, ...])

计算沿指定轴的数据分位数。

r_

沿第一个轴连接切片、标量和类数组对象。

rad2deg(x, /)

将角度从弧度转换为度。

radians(x, /)

jax.numpy.deg2rad()的别名

ravel(a[, order])

将数组展平为一维形状。

ravel_multi_index(multi_index, dims[, mode, ...])

将多维索引转换为扁平索引。

real(val, /)

返回复数参数的逐元素实部。

reciprocal(x, /)

计算输入的逐元素倒数。

remainder(x1, x2, /)

返回除法的逐元素余数。

repeat(a, repeats[, axis, total_repeat_length])

从重复的元素构造一个数组。

reshape(a[, shape, order, newshape, copy])

返回数组的重塑副本。

resize(a, new_shape)

返回具有指定形状的新数组。

result_type(*args)

返回将 JAX 提升规则应用于输入的结果。

right_shift(x1, x2, /)

x1 的位向右移动 x2 中指定的量。

rint(x, /)

将 x 的元素四舍五入到最接近的整数。

roll(a, shift[, axis])

沿指定轴滚动数组的元素。

rollaxis(a, axis[, start])

将指定的轴滚动到给定的位置。

roots(p, *[, strip_zeros])

返回给定系数 p 的多项式的根。

rot90(m[, k, axes])

在 axes 指定的平面中,将数组逆时针旋转 90 度。

round(a[, decimals, out])

将输入均匀地舍入到给定的小数位数。

s_

一种构建数组索引元组的更友好的方法。

save(file, arr[, allow_pickle, fix_imports])

以 NumPy .npy 格式将数组保存到二进制文件中。

savez(file, *args[, allow_pickle])

以未压缩的 .npz 格式将多个数组保存到单个文件中。

searchsorted(a, v[, side, sorter, method])

在排序数组中执行二分查找。

select(condlist, choicelist[, default])

根据一系列条件选择值。

set_printoptions([precision, threshold, ...])

设置打印选项。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

计算两个 1D 数组的集合差。

setxor1d(ar1, ar2[, assume_unique, size, ...])

计算两个数组中元素的集合式异或。

shape(a)

返回数组的形状。

sign(x, /)

返回输入的逐元素符号指示。

signbit(x, /)

返回数组元素的符号位。

signedinteger()

所有有符号整数标量类型的抽象基类。

sin(x, /)

计算输入的每个元素的三角正弦值。

sinc(x, /)

计算归一化 sinc 函数。

single

float32的别名

sinh(x, /)

计算输入的逐元素双曲正弦值。

size(a[, axis])

返回沿给定轴的元素数量。

sort(a[, axis, kind, order, stable, descending])

返回数组的排序副本。

sort_complex(a)

返回复数数组的排序副本。

spacing(x, /)

返回 x 与下一个相邻数字之间的间距。

split(ary, indices_or_sections[, axis])

将数组拆分为子数组。

sqrt(x, /)

计算输入数组的逐元素非负平方根。

square(x, /)

计算输入数组的逐元素平方。

squeeze(a[, axis])

从数组中删除一个或多个长度为 1 的轴。

stack(arrays[, axis, out, dtype])

沿新轴连接数组。

std(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

subtract

逐元素减去两个数组。

sum(a[, axis, dtype, out, keepdims, ...])

求数组沿给定轴的元素之和。

swapaxes(a, axis1, axis2)

交换数组的两个轴。

take(a, indices[, axis, out, mode, ...])

从数组中提取元素。

take_along_axis(arr, indices, axis[, mode, ...])

从数组中提取元素。

tan(x, /)

计算输入中每个元素的三角正切值。

tanh(x, /)

计算输入的逐元素双曲正切值。

tensordot(a, b[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tile(A, reps)

通过沿指定维度重复 A 来构造数组。

trace(a[, offset, axis1, axis2, dtype, out])

计算输入沿给定轴的对角线之和。

trapezoid(y[, x, dx, axis])

使用复合梯形规则沿给定轴积分。

transpose(a[, axes])

返回 N 维数组的转置版本。

tri(N[, M, k, dtype])

返回一个在对角线及以下为 1,其他地方为 0 的数组。

tril(m[, k])

返回数组的下三角部分。

tril_indices(n[, k, m])

返回大小为 (n, m) 的数组的下三角部分的索引。

tril_indices_from(arr[, k])

返回给定数组的下三角部分的索引。

trim_zeros(filt[, trim])

修剪输入数组的前导和/或尾随零。

triu(m[, k])

返回数组的上三角部分。

triu_indices(n[, k, m])

返回大小为 (n, m) 的数组的上三角部分的索引。

triu_indices_from(arr[, k])

返回给定数组的上三角部分的索引。

true_divide(x1, x2, /)

逐元素计算 x1 除以 x2 的结果

trunc(x)

将输入四舍五入到最接近于零的整数。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

通用函数,对数组逐元素执行操作。

uint

uint64 的别名

uint16(x)

uint16 类型的 JAX 标量构造函数。

uint32(x)

uint32 类型的 JAX 标量构造函数。

uint64(x)

uint64 类型的 JAX 标量构造函数。

uint8(x)

uint8 类型的 JAX 标量构造函数。

union1d(ar1, ar2, *[, size, fill_value])

计算两个一维数组的并集。

unique(ar[, return_index, return_inverse, ...])

返回数组中的唯一值。

unique_all(x, /, *[, size, fill_value])

返回 x 中的唯一值,以及索引、逆索引和计数。

unique_counts(x, /, *[, size, fill_value])

返回 x 中的唯一值以及计数。

unique_inverse(x, /, *[, size, fill_value])

返回 x 中的唯一值,以及索引、逆索引和计数。

unique_values(x, /, *[, size, fill_value])

返回 x 中的唯一值,以及索引、逆索引和计数。

unpackbits(a[, axis, count, bitorder])

解包 uint8 数组中的位。

unravel_index(indices, shape)

将扁平索引转换为多维索引。

unstack(x, /, *[, axis])

沿轴解堆叠数组。

unsignedinteger()

所有无符号整数标量类型的抽象基类。

unwrap(p[, discont, axis, period])

展开周期信号。

vander(x[, N, increasing])

生成范德蒙矩阵。

var(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的方差。

vdot(a, b, *[, precision, ...])

执行两个一维向量的共轭乘法。

vecdot(x1, x2, /, *[, axis, precision, ...])

执行两个批量向量的共轭乘法。

vecmat(x1, x2, /)

批量共轭向量-矩阵乘积。

vectorize(pyfunc, *[, excluded, signature])

定义一个具有广播功能的向量化函数。

vsplit(ary, indices_or_sections)

将数组垂直拆分为子数组。

vstack(tup[, dtype])

垂直堆叠数组。

where(condition[, x, y, size, fill_value])

根据条件从两个数组中选择元素。

zeros(shape[, dtype, device])

创建一个充满零的数组。

zeros_like(a[, dtype, shape, device])

创建一个与数组具有相同形状和 dtype 的充满零的数组。

jax.numpy.fft#

fft(a[, n, axis, norm])

沿给定轴计算一维离散傅里叶变换。

fft2(a[, s, axes, norm])

沿给定轴计算二维离散傅里叶变换。

fftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

fftn(a[, s, axes, norm])

沿给定轴计算多维离散傅里叶变换。

fftshift(x[, axes])

将零频率 fft 分量移动到频谱的中心。

hfft(a[, n, axis, norm])

计算频谱具有埃尔米特对称性的数组的一维 FFT。

ifft(a[, n, axis, norm])

计算一维离散傅里叶逆变换。

ifft2(a[, s, axes, norm])

计算二维逆离散傅里叶变换。

ifftn(a[, s, axes, norm])

计算多维逆离散傅里叶变换。

ifftshift(x[, axes])

jax.numpy.fft.fftshift() 的逆运算。

ihfft(a[, n, axis, norm])

计算频谱具有厄米对称性的数组的一维逆 FFT。

irfft(a[, n, axis, norm])

计算实值一维逆离散傅里叶变换。

irfft2(a[, s, axes, norm])

计算实值二维逆离散傅里叶变换。

irfftn(a[, s, axes, norm])

计算实值多维逆离散傅里叶变换。

rfft(a[, n, axis, norm])

计算实值数组的一维离散傅里叶变换。

rfft2(a[, s, axes, norm])

计算实值数组的二维离散傅里叶变换。

rfftfreq(n[, d, dtype, device])

返回离散傅里叶变换的采样频率。

rfftn(a[, s, axes, norm])

计算实值数组的多维离散傅里叶变换。

jax.numpy.linalg#

cholesky(a, *[, upper])

计算矩阵的 Cholesky 分解。

cond(x[, p])

计算矩阵的条件数。

cross(x1, x2, /, *[, axis])

计算两个 3D 向量的叉积。

det(a)

计算数组的行列式。

diagonal(x, /, *[, offset])

提取矩阵或矩阵堆叠的对角线。

eig(a)

计算方阵的特征值和特征向量。

eigh(a[, UPLO, symmetrize_input])

计算埃尔米特矩阵的特征值和特征向量。

eigvals(a)

计算一般矩阵的特征值。

eigvalsh(a[, UPLO])

计算埃尔米特矩阵的特征值。

inv(a)

返回方阵的逆矩阵。

lstsq(a, b[, rcond, numpy_resid])

返回线性方程的最小二乘解。

matmul(x1, x2, /, *[, precision, ...])

执行矩阵乘法。

matrix_norm(x, /, *[, keepdims, ord])

计算矩阵或矩阵堆叠的范数。

matrix_power(a, n)

将方阵提升到整数幂。

matrix_rank(M[, rtol, tol])

计算矩阵的秩。

matrix_transpose(x, /)

转置矩阵或矩阵堆叠。

multi_dot(arrays, *[, precision])

高效地计算一系列数组之间的矩阵乘积。

norm(x[, ord, axis, keepdims])

计算矩阵或向量的范数。

outer(x1, x2, /)

计算两个一维数组的外积。

pinv(a[, rtol, hermitian, rcond])

计算矩阵的(Moore-Penrose)伪逆。

qr(a[, mode])

计算数组的 QR 分解。

slogdet(a, *[, method])

计算数组的行列式的符号和(自然)对数。

solve(a, b)

求解线性方程组。

svd(a[, full_matrices, compute_uv, ...])

计算奇异值分解。

svdvals(x, /)

计算矩阵的奇异值。

tensordot(x1, x2, /, *[, axes, precision, ...])

计算两个 N 维数组的张量点积。

tensorinv(a[, ind])

计算数组的张量逆。

tensorsolve(a, b[, axes])

求解张量方程 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

计算矩阵的迹。

vector_norm(x, /, *[, axis, keepdims, ord])

计算向量或批量向量的向量范数。

vecdot(x1, x2, /, *[, axis, precision, ...])

计算两个数组的(批量)向量共轭点积。

JAX 数组#

JAX Array (以及它的别名 jax.numpy.ndarray)是 JAX 中的核心数组对象:你可以把它看作是 JAX 中与 numpy.ndarray 等价的对象。和 numpy.ndarray 一样,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()arange()linspace() 以及上面列出的其他函数)来创建它们。

复制和序列化#

JAX Array 对象被设计为在适当的情况下与 Python 标准库工具无缝协作。

使用内置的 copy 模块时,当 copy.copy()copy.deepcopy() 遇到 Array 时,它等效于调用 copy() 方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这在跟踪/JIT 编译的代码中可以正常工作,尽管在这种情况下,编译器可能会省略复制操作。

当内置的 pickle 模块遇到 Array 时,它将通过紧凑的位表示进行序列化,其方式类似于 pickled 的 numpy.ndarray 对象。当反序列化时,结果将是在默认设备上 的新的 Array 对象。这是因为通常,序列化和反序列化可能发生在不同的运行时环境中,并且没有通用的方法将一个运行时的设备 ID 映射到另一个运行时的设备 ID。如果在跟踪/JIT 编译的代码中使用 pickle,则会导致 ConcretizationTypeError

Python 数组 API 标准#

注意

在 JAX v0.4.32 之前,你必须 import jax.experimental.array_api 才能为 JAX 数组启用数组 API。在 JAX v0.4.32 之后,不再需要导入此模块,并且会引发弃用警告。

从 JAX v0.4.32 开始,jax.Arrayjax.numpyPython 数组 API 标准兼容。你可以通过 jax.Array.__array_namespace__() 访问数组 API 命名空间。

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在一些地方偏离了标准,主要是因为 JAX 数组是不可变的,不支持原地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。

有关更多信息,请参阅 Python 数组 API 标准 文档。