jax.numpy.reshape#
- jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[源代码]#
返回一个重塑后的数组副本。
numpy.reshape()
的 JAX 实现,基于jax.lax.reshape()
实现。- 参数:
a (ArrayLike) – 要重塑的输入数组
shape (DimSize | Shape | None | None) – 整数或整数序列,给出新形状,必须与输入数组的大小匹配。如果任何一个维度的大小为
-1
,它将被替换为一个值,使输出具有正确的大小。order ( str ) –
'F'
或'C'
,指定重塑操作应使用列优先(Fortran 风格,"F"
)还是行优先(C 风格,"C"
)顺序;默认值为"C"
。JAX 不支持order="A"
。copy ( bool | None | None ) – JAX 未使用;JAX 始终返回副本,但在 JIT 下,编译器可能会优化掉这些副本。
newshape (DimSize | Shape | DeprecatedArg) – 已弃用的
shape
参数的别名。如果使用,将会导致DeprecationWarning
。
- 返回值:
具有指定形状的输入数组的重塑副本。
- 返回类型:
注释
与
numpy.reshape()
不同,jax.numpy.reshape()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。另请参阅
jax.Array.reshape()
:通过数组方法实现的等效功能。jax.numpy.ravel()
:将数组展平为 1D 形状。jax.numpy.squeeze()
:从数组形状中删除一个或多个长度为 1 的轴。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
可以使用
-1
自动计算与输入大小一致的形状>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
重塑中轴的默认顺序是 C 风格的行优先顺序。要使用 Fortran 风格的列优先顺序,请指定
order='F'
>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
为方便起见,此功能也可以通过
jax.Array.reshape()
方法使用>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)