jax.numpy.dsplit

内容

jax.numpy.dsplit#

jax.numpy.dsplit(ary, indices_or_sections)[source]#

沿深度方向将数组拆分为子数组。

JAX 对 numpy.dsplit() 的实现。

有关详细信息,请参阅 jax.numpy.split() 的文档。 dsplit 等效于 split,其中 axis=2

示例

>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]

另请参阅

参数:
  • ary (ArrayLike)

  • indices_or_sections (int | Sequence[int] | ArrayLike)

返回类型:

list[Array]