jax.numpy.copy

内容

jax.numpy.copy#

jax.numpy.copy(a, order=None)[source]#

返回数组的副本。

JAX 实现 numpy.copy().

参数:
  • a (ArrayLike) – 要复制的类数组对象

  • order (str | None | None) – 在 JAX 中未实现

返回:

输入数组 a 的副本。

返回类型:

Array

另请参阅

示例

由于 JAX 数组是不可变的,在大多数情况下,不需要显式复制数组。一个例外是当使用带有捐赠参数的函数时(请参阅 jax.jit()donate_argnums 参数)。

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

因为我们标记了 x 为捐赠,所以原始数组不再可用。

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

在这种情况下,显式复制将使您能够继续访问原始缓冲区。

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]