jax.lax.split# jax.lax.split(operand, sizes, axis=0)[源代码]# 沿 axis 分割数组。 参数: operand (ArrayLike) – 要分割的数组 sizes (Sequence[int]) – 分割数组的大小序列。这些大小的总和必须等于 operand 的 axis 维度的大小。 axis (int) – 沿其分割数组的轴。 返回值: 一个长度为 len(sizes) 的数组序列。 如果 sizes 是 [s1, s2, ...],此函数将返回沿 axis 方向取出的,大小分别为 s1,s2 的块。 返回类型: Sequence[Array]