jax.lax.collapse# jax.lax.collapse(operand, start_dimension, stop_dimension=None)[source]# 将数组的维度折叠成单个维度。 例如,如果 operand 是一个形状为 [2, 3, 4] 的数组,那么 collapse(operand, 0, 2).shape == [6, 4]。折叠维度的元素按从大到小的顺序排列,即最低编号的维度作为变化最慢的维度。 参数: operand (Array) – 输入数组。 start_dimension (int) – 要折叠的维度的开始位置(包含)。 stop_dimension (int | None | None) – 要折叠的维度的结束位置(不包含)。传递 None 表示折叠 start 之后的所有维度。 返回值: 一个数组,其中维度 [start_dimension, stop_dimension) 已折叠(展平)成单个维度。 返回值类型: Array