jax.numpy.load#
- jax.numpy.load(file, *args, **kwargs)[源代码]#
从 npy 文件加载 JAX 数组。
JAX 对
numpy.load()
的封装。此函数是
numpy.load()
的简单封装,但在使用numpy.save()
或jax.numpy.save()
创建的.npy
文件的情况下,输出将作为jax.Array
返回,并且bfloat16
数据类型将被恢复。对于.npz
文件,结果将作为普通的 NumPy 数组返回。此函数需要具体的数组输入,并且与
jax.jit()
或jax.vmap()
等变换不兼容。- 参数:
file (IO[bytes] | str | os.PathLike[Any]) – 包含数组数据的字符串、字节或类路径对象。
args (Any) – 对于其他参数,请参阅
numpy.load()
kwargs (Any) – 对于其他参数,请参阅
numpy.load()
- 返回:
存储在文件中的数组。
- 返回类型:
另请参阅
jax.numpy.save()
: 将数组保存到文件。
示例
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)