jax.numpy.flip

内容

jax.numpy.flip#

jax.numpy.flip(m, axis=None)[source]#

沿给定轴反转数组元素的顺序。

JAX 实现的 numpy.flip().

参数:
  • m (ArrayLike) – 数组。

  • axis (int | Sequence[int] | None | None) – 整数或整数序列。指定应反转数组元素的轴或轴的序列。默认为 None,沿所有轴反转。

返回值:

沿 axis 反转元素顺序的数组。

返回类型:

数组

另请参见

示例

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)

如果 axis 使用整数指定,则 jax.numpy.flip 仅沿该特定轴反转数组。

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],

       [[4, 3],
        [2, 1]]], dtype=int32)

如果 axis 使用整数序列指定,则 jax.numpy.flip 沿指定轴反转数组。

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],

       [[8, 7],
        [6, 5]]], dtype=int32)