jax.numpy.copy#

jax.numpy.copy(a, order=None)[源代码]#

返回数组的副本。

JAX 实现的 numpy.copy().

参数:
  • a (ArrayLike) – 要复制的类数组对象

  • order (str | None | None) – 在 JAX 中未实现

返回:

输入数组 a 的副本。

返回类型:

数组

另请参阅

示例

由于 JAX 数组是不可变的,因此在大多数情况下,不需要显式数组复制。一个例外是使用带有捐赠参数的函数(请参阅 jax.jit()donate_argnums 参数)。

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

因为我们标记了 x 被捐赠,所以原始数组不再可用

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

在这种情况下,显式复制将使你保留对原始缓冲区的访问权限

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]