jax.experimental.array_api
模块#
注意
自 JAX v0.4.32 起,jax.experimental.array_api
模块已弃用,不再需要导入 jax.experimental.array_api
。{mod}`jax.numpy` 默认直接实现数组 API 标准。有关详细信息,请参阅 Python 数组 API 标准。
此模块包含对 Python 数组 API 标准的实验性 JAX 支持。目前对此的支持是实验性的,尚未完全完成。
使用示例
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2023.12'
>>> arr = xp.arange(1000)
>>> arr.sum()
Array(499500, dtype=int32)
xp
命名空间是符合数组 API 的 jax.numpy
的对等物,并且实现了标准中列出的大部分 API。