jax.numpy.load#
- jax.numpy.load(file, *args, **kwargs)[源代码]#
从 npy 文件加载 JAX 数组。
numpy.load()
的 JAX 包装器。此函数是
numpy.load()
的一个简单包装器,但在使用numpy.save()
或jax.numpy.save()
创建的.npy
文件的情况下,输出将作为jax.Array
返回,并且bfloat16
数据类型将被恢复。对于.npz
文件,结果将作为普通的 NumPy 数组返回。此函数需要具体的数组输入,并且不兼容像
jax.jit()
或jax.vmap()
这样的转换。- 参数:
file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节或类路径对象。
args (Any) – 有关其他参数,请参阅
numpy.load()
kwargs (Any) – 有关其他参数,请参阅
numpy.load()
- 返回:
存储在文件中的数组。
- 返回类型:
另请参阅
jax.numpy.save()
:将数组保存到文件。
示例
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)