jax.numpy.reshape#
- jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[source]#
返回数组的重塑副本。
JAX 实现
numpy.reshape()
,使用jax.lax.reshape()
实现。- 参数:
a (ArrayLike) – 要重塑的输入数组
shape (DimSize | Shape | None | None) – 给出新形状的整数或整数序列,必须与输入数组的大小匹配。如果任何单个维度的大小为
-1
,则将用一个值替换它,以使输出具有正确的大小。order (str) –
'F'
或'C'
,指定重塑应用列主序(Fortran 风格,"F"
)还是行主序(C 风格,"C"
);默认值为"C"
。JAX 不支持order="A"
。copy (bool | None | None) – JAX 未使用;JAX 始终返回副本,但在 JIT 下,编译器可能会优化掉此类副本。
newshape (DimSize | Shape | DeprecatedArg) –
shape
参数的弃用别名。如果使用,将导致DeprecationWarning
。
- 返回:
具有指定形状的输入数组的重塑副本。
- 返回类型:
笔记
与
numpy.reshape()
不同,jax.numpy.reshape()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会影响性能。另请参阅
jax.Array.reshape()
: 通过数组方法实现等效功能。jax.numpy.ravel()
: 将数组展平为一维形状。jax.numpy.squeeze()
: 从数组形状中移除一个或多个长度为 1 的轴。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
您可以使用
-1
自动计算与输入大小一致的形状>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
重塑中轴的默认顺序是 C 样式的逐行主序。要使用 Fortran 样式的逐列主序,请指定
order='F'
>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
为了方便起见,此功能也可以通过
jax.Array.reshape()
方法获得>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)