jax.lax.reshape#

jax.lax.reshape(operand, new_sizes, dimensions=None, sharding=None)[源代码]#

封装 XLA 的 Reshape 操作符。

对于插入/删除大小为 1 的维度,请优先使用 lax.squeeze / lax.expand_dims。这些操作保留了关于轴标识的信息,这对于高级转换规则可能很有用。

参数:
  • operand (ArrayLike) – 要重塑形状的数组。

  • new_sizes (Shape) – 指定结果形状的整数序列。最终数组的大小必须与输入的大小匹配。

  • dimensions (Sequence[int] | None | None) – 可选的整数序列,指定输入形状的置换顺序。如果指定,则长度必须与 operand.shape 匹配。

  • sharding (NamedSharding | P | None | None)

返回:

重塑形状后的数组。

返回类型:

out

示例

从一维到二维的简单重塑形状

>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
             [3, 4, 5]], dtype=int32)

重塑形状回一维

>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)

通过维度置换重塑形状到一维

>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)