公共 API:jax
包#
子包#
jax.numpy
模块jax.scipy
模块jax.lax
模块jax.random
模块jax.sharding
模块jax.debug
模块jax.dlpack
模块jax.distributed
模块jax.dtypes
模块jax.ffi
模块jax.extend.ffi
模块 (已弃用)jax.flatten_util
模块jax.image
模块jax.nn
模块jax.ops
模块jax.profiler
模块jax.stages
模块jax.tree
模块jax.tree_util
模块jax.typing
模块jax.export
模块jax.extend
模块jax.example_libraries
模块jax.experimental
模块
配置#
jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
jax_check_tracer_leaks 配置选项的上下文管理器。 |
|
jax_debug_nans 配置选项的上下文管理器。 |
|
jax_debug_infs 配置选项的上下文管理器。 |
|
jax_default_device 配置选项的上下文管理器。 |
|
jax_default_matmul_precision 配置选项的上下文管理器。 |
|
jax_default_prng_impl 配置选项的上下文管理器。 |
|
jax_enable_checks 配置选项的上下文管理器。 |
|
jax_enable_custom_prng 配置选项的上下文管理器(临时)。 |
|
jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器(临时)。 |
|
jax_log_compiles 配置选项的上下文管理器。 |
|
jax_numpy_rank_promotion 配置选项的上下文管理器。 |
|
|
控制所有传输的传输保护级别的上下文管理器。 |
即时编译 (jit
)#
|
为使用 XLA 进行即时编译设置 |
|
禁用其动态上下文下的 |
确保在跟踪/编译时进行评估(或报错)的上下文管理器。 |
|
|
创建一个函数,该函数给定示例参数,生成其 jaxpr。 |
|
计算 |
|
数组的形状、dtype 和其他静态属性的容器。 |
|
将 |
|
将 |
返回默认 XLA 后端的平台名称。 |
|
|
在暂存 JAX 计算时,将用户指定的名称添加到函数。 |
|
将用户指定的名称添加到 JAX 名称堆栈的上下文管理器。 |
尝试在 pytree 叶子上调用 |
|
|
使用指定的形状和轴名称创建高效的网格。 |
自动微分#
|
创建一个评估 |
|
创建一个评估 |
|
|
|
使用前向模式 AD 逐列评估 |
|
使用反向模式 AD 逐行评估 |
|
|
|
计算 |
使用 |
|
|
转置一个承诺是线性的函数。 |
|
计算 |
|
用于定义自定义 VJP 规则(又名自定义梯度)的便捷函数。 |
|
闭包转换实用程序,用于高阶自定义导数。 |
|
使 |
自定义#
custom_jvp
#
|
为自定义 JVP 规则定义设置可 JAX 转换的函数。 |
|
为此实例表示的函数定义自定义 JVP 规则。 |
|
用于为每个参数单独定义 JVP 的便捷包装器。 |
custom_vjp
#
|
为自定义 VJP 规则定义设置可 JAX 转换的函数。 |
|
为此实例表示的函数定义自定义 VJP 规则。 |
custom_batching
#
自定义可 JAX 转换函数的 vmap 行为。 |
|
|
为此 custom_vmap 函数定义 vmap 规则。 |
|
jax.Array (jax.Array
)#
|
JAX 的数组基类 |
|
通过从 |
|
从一系列在单个设备上的 |
|
使用进程中可用的数据创建分布式张量。 |
数组属性和方法#
可寻址分片的列表。 |
|
|
测试给定轴上的所有数组元素是否都为 True。 |
|
测试给定轴上的任何数组元素是否为 True。 |
|
返回最大值的索引。 |
|
返回最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回对数组进行排序的索引。 |
|
复制数组并转换为指定的 dtype。 |
用于索引更新功能的辅助属性。 |
|
|
构造一个从多个数组的元素中选择的数组。 |
|
返回一个其值限制在指定范围内的数组。 |
|
返回沿给定轴的此数组的选定切片。 |
数组是否已提交。 |
|
返回数组的复共轭。 |
|
返回数组的复共轭。 |
|
返回数组的副本。 |
|
异步地将 |
|
|
返回数组的累积乘积。 |
|
返回数组的累积和。 |
与 Array API 兼容的设备属性。 |
|
|
返回数组中指定的对角线。 |
|
计算两个数组的点积。 |
数组的数据类型 ( |
|
请改用 |
|
|
将数组展平为一维形状。 |
全局分片的列表。 |
|
返回数组的虚部。 |
|
此数组是否完全可寻址? |
|
此数组是否完全复制? |
|
|
将数组的元素复制到标准的 Python 标量并返回它。 |
一个数组元素以字节为单位的长度。 |
|
|
返回给定轴上的数组元素的最大值。 |
|
返回给定轴上的数组元素的平均值。 |
|
返回给定轴上的数组元素的最小值。 |
数组元素消耗的总字节数。 |
|
数组中的维度数。 |
|
|
返回数组非零元素的索引。 |
|
返回给定轴上数组元素的乘积。 |
|
返回给定轴上的峰峰值范围。 |
|
将数组展平为一维形状。 |
返回数组的实部。 |
|
|
从重复的元素构造数组。 |
|
返回一个包含相同数据的新形状的数组。 |
|
将数组元素四舍五入到给定的小数位。 |
|
在排序的数组中执行二进制搜索。 |
数组的形状。 |
|
数组的分片。 |
|
数组中元素的总数。 |
|
|
返回数组的排序副本。 |
|
从数组中删除一个或多个长度为 1 的轴。 |
|
计算给定轴上的标准差。 |
|
计算给定轴上数组元素的总和。 |
|
交换数组的两个轴。 |
|
从数组中获取元素。 |
|
返回指定设备上数组的副本。 |
|
返回沿对角线的总和。 |
|
返回一个轴已转置的数组副本。 |
|
计算给定轴上的方差。 |
|
返回数组的按位副本,将其视为新的 dtype。 |
计算全轴数组转置。 |
|
计算(批处理)矩阵转置。 |
矢量化 (vmap
)#
|
矢量化映射。 |
|
定义一个具有广播功能的矢量化函数。 |
并行化 (pmap
)#
|
支持集体操作的并行映射。 |
|
返回给定后端的所有设备的列表。 |
|
类似于 |
|
返回此进程的整数进程索引。 |
|
返回设备总数。 |
|
返回此进程可寻址的设备数量。 |
|
返回与后端关联的 JAX 进程数量。 |
|
返回与后端关联的所有 JAX 进程索引的列表。 |
回调函数#
|
调用纯 Python 回调函数。 |
|
调用非纯 Python 回调函数。 |
|
调用可分阶段的 Python 回调函数。 |
|
打印值,并可在分阶段的 JAX 函数中使用。 |
其他#
可用设备的描述符。 |
|
|
返回一个包含本地环境和 JAX 安装信息的字符串。 |
|
返回 platform 后端中的所有活动数组。 |
清除所有编译和分阶段缓存。 |