jax.numpy.load#

jax.numpy.load(file, *args, **kwargs)[源代码]#

从 npy 文件加载 JAX 数组。

numpy.load() 的 JAX 包装器。

此函数是 numpy.load() 的一个简单包装器,但在使用 numpy.save()jax.numpy.save() 创建的 .npy 文件的情况下,输出将作为 jax.Array 返回,并且 bfloat16 数据类型将被恢复。对于 .npz 文件,结果将作为普通的 NumPy 数组返回。

此函数需要具体的数组输入,并且不兼容像 jax.jit()jax.vmap() 这样的转换。

参数:
  • file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节或类路径对象。

  • args (Any) – 有关其他参数,请参阅 numpy.load()

  • kwargs (Any) – 有关其他参数,请参阅 numpy.load()

返回:

存储在文件中的数组。

返回类型:

Array

另请参阅

示例

>>> import io
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)