jax.numpy.reshape#
- jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[源代码]#
返回数组的重塑副本。
JAX 的
numpy.reshape()
实现,根据jax.lax.reshape()
实现。- 参数:
a (类数组) – 要重塑的输入数组
shape (DimSize | Shape | None | None) – 整数或整数序列,表示新的形状,必须与输入数组的大小匹配。如果任何一个维度的大小被指定为
-1
,它将被替换为一个值,使得输出具有正确的大小。order (str) –
'F'
或'C'
,指定重塑操作应采用列优先(Fortran 风格,"F"
)还是行优先(C 风格,"C"
)顺序;默认值为"C"
。JAX 不支持order="A"
。copy (bool | None | None) – JAX 未使用;JAX 始终返回一个副本,但在 JIT 下,编译器可能会优化掉这些副本。
newshape (DimSize | Shape | DeprecatedArg) – 已弃用的
shape
参数别名。如果使用,将会导致DeprecationWarning
。
- 返回:
返回具有指定形状的输入数组的重塑副本。
- 返回类型:
注意
与
numpy.reshape()
不同,jax.numpy.reshape()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。另请参阅
jax.Array.reshape()
:通过数组方法实现等效功能。jax.numpy.ravel()
:将数组展平为一维形状。jax.numpy.squeeze()
:从数组的形状中删除一个或多个长度为 1 的轴。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
您可以使用
-1
来自动计算与输入大小一致的形状>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
重塑中轴的默认顺序是 C 风格的行优先顺序。要使用 Fortran 风格的列优先顺序,请指定
order='F'
>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
为方便起见,此功能也可以通过
jax.Array.reshape()
方法获得>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)