jax.experimental.array_api 模块

jax.experimental.array_api 模块#

注意

从 JAX v0.4.32 开始,jax.experimental.array_api 模块已弃用,不再需要导入 jax.experimental.array_api。{mod}`jax.numpy` 默认直接实现数组 API 标准。有关详细信息,请参阅 Python 数组 API 标准

此模块包含对 Python 数组 API 标准 的实验性 JAX 支持。目前对它的支持是实验性的,并非完全完成。

示例用法

>>> from jax.experimental import array_api as xp

>>> xp.__array_api_version__
'2023.12'

>>> arr = xp.arange(1000)

>>> arr.sum()
Array(499500, dtype=int32)

xp 命名空间是 jax.numpy 的数组 API 兼容模拟,它实现了标准中列出的 API 的大部分。