jax.Array#
- class jax.Array#
JAX 的数组基类
jax.Array
是 JAX 数组和追踪器的公共接口,用于实例检查和类型注解。它的主要应用是实例检查和类型注解;例如:x = jnp.arange(5) isinstance(x, jax.Array) # returns True both inside and outside traced functions. def f(x: Array) -> Array: # type annotations are valid for traced and non-traced types. return x
jax.Array
不应该直接用于创建数组;相反,您应该使用jax.numpy
中提供的数组创建例程,例如jax.numpy.array()
,jax.numpy.zeros()
,jax.numpy.ones()
,jax.numpy.full()
,jax.numpy.arange()
等。- __init__()#
方法
__init__
()addressable_data
(index)返回特定索引处的可寻址数据数组。
all
([axis, out, keepdims, where])测试给定轴上的所有数组元素是否都为 True。
any
([axis, out, keepdims, where])测试给定轴上的任何数组元素是否为 True。
argmax
([axis, out, keepdims])返回最大值的索引。
argmin
([axis, out, keepdims])返回最小值的索引。
argpartition
(kth[, axis])返回部分排序数组的索引。
argsort
([axis, kind, order, stable, descending])返回排序数组的索引。
astype
(dtype[, copy, device])复制数组并强制转换为指定的 dtype。
choose
(choices[, out, mode])构造一个从多个数组的元素中选择的数组。
clip
([min, max])返回一个值限制在指定范围内的数组。
compress
(condition[, axis, out, size, ...])返回沿给定轴的此数组的选定切片。
conj
()返回数组的复共轭。
返回数组的复共轭。
copy
()返回数组的副本。
异步地将
Array
复制到主机。cumprod
([axis, dtype, out])返回数组的累积乘积。
cumsum
([axis, dtype, out])返回数组的累积和。
diagonal
([offset, axis1, axis2])返回数组中指定的对角线。
dot
(b, *[, precision, preferred_element_type])计算两个数组的点积。
flatten
([order])将数组展平为一维形状。
item
(*args)将数组的元素复制到标准的 Python 标量并返回。
max
([axis, out, keepdims, initial, where])返回给定轴上的数组元素的最大值。
mean
([axis, dtype, out, keepdims, where])返回给定轴上的数组元素的平均值。
min
([axis, out, keepdims, initial, where])返回给定轴上的数组元素的最小值。
nonzero
(*[, fill_value, size])返回数组非零元素的索引。
prod
([axis, dtype, out, keepdims, initial, ...])返回给定轴上数组元素的乘积。
ptp
([axis, out, keepdims])返回给定轴上的峰峰值范围。
ravel
([order])将数组展平为一维形状。
repeat
(repeats[, axis, total_repeat_length])从重复元素构造数组。
reshape
(*args[, order])返回一个包含相同数据且具有新形状的数组。
round
([decimals, out])将数组元素四舍五入到给定的小数位。
searchsorted
(v[, side, sorter, method])在排序数组中执行二进制搜索。
sort
([axis, kind, order, stable, descending])返回数组的排序副本。
squeeze
([axis])从数组中删除一个或多个长度为 1 的轴。
std
([axis, dtype, out, ddof, keepdims, ...])计算沿给定轴的标准偏差。
sum
([axis, dtype, out, keepdims, initial, ...])求给定轴上数组元素的总和。
swapaxes
(axis1, axis2)交换数组的两个轴。
take
(indices[, axis, out, mode, ...])从数组中获取元素。
to_device
(device, *[, stream])返回指定设备上数组的副本
trace
([offset, axis1, axis2, dtype, out])返回沿对角线的总和。
transpose
(*args)返回一个轴转置的数组副本。
var
([axis, dtype, out, ddof, keepdims, ...])计算沿给定轴的方差。
view
([dtype, type])返回数组的按位副本,视为新的 dtype。
属性
计算全轴数组转置。
可寻址分片的列表。
用于索引更新功能的辅助属性。
数组是否已提交。
与数组 API 兼容的设备属性。
数组的数据类型 (
numpy.dtype
)。请改用
flatten()
。全局分片的列表。
返回数组的虚部。
这个数组是否完全可寻址?
这个数组是否完全复制?
一个数组元素的长度(以字节为单位)。
计算(批处理的)矩阵转置。
数组元素消耗的总字节数。
数组的维度数。
返回数组的实部。
数组的形状。
数组的分片。
数组中的元素总数。