jax.numpy.reshape#

jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[源代码]#

返回一个重塑后的数组副本。

numpy.reshape() 的 JAX 实现,基于 jax.lax.reshape() 实现。

参数:
  • a (ArrayLike) – 要重塑的输入数组

  • shape (DimSize | Shape | None | None) – 整数或整数序列,给出新形状,必须与输入数组的大小匹配。如果任何一个维度的大小为 -1,它将被替换为一个值,使输出具有正确的大小。

  • order ( str ) – 'F''C',指定重塑操作应使用列优先(Fortran 风格,"F")还是行优先(C 风格,"C")顺序;默认值为 "C"。JAX 不支持 order="A"

  • copy ( bool | None | None ) – JAX 未使用;JAX 始终返回副本,但在 JIT 下,编译器可能会优化掉这些副本。

  • newshape (DimSize | Shape | DeprecatedArg) – 已弃用的 shape 参数的别名。如果使用,将会导致 DeprecationWarning

返回值:

具有指定形状的输入数组的重塑副本。

返回类型:

Array

注释

numpy.reshape() 不同,jax.numpy.reshape() 将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。

另请参阅

示例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

可以使用 -1 自动计算与输入大小一致的形状

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

重塑中轴的默认顺序是 C 风格的行优先顺序。要使用 Fortran 风格的列优先顺序,请指定 order='F'

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)

为方便起见,此功能也可以通过 jax.Array.reshape() 方法使用

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)