jax.numpy.rot90#

jax.numpy.rot90(m, k=1, axes=(0, 1))[源代码]#

在由轴指定的平面内将数组逆时针旋转 90 度。

JAX 实现的 numpy.rot90()

参数:
  • m (类似数组) – 输入数组。 必须有 m.ndim >= 2

  • k (int) – int,可选,默认值为 1。指定数组旋转的次数。 对于 k 的负值,数组将按顺时针方向旋转。

  • axes (tuple[int, int]) – 包含 2 个整数的元组,可选,默认值为 (0, 1)。 这些轴定义了数组旋转所在的平面。 两个轴必须不同。

返回值:

一个包含输入数组副本的数组,m 旋转了 90 度。

返回类型:

数组

另请参阅

示例

>>> m = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.rot90(m)
Array([[3, 6],
       [2, 5],
       [1, 4]], dtype=int32)
>>> jnp.rot90(m, k=2)
Array([[6, 5, 4],
       [3, 2, 1]], dtype=int32)

jnp.rot90(m, k=1, axes=(1, 0)) 等价于 jnp.rot90(m, k=-1, axes(0,1))

>>> jnp.rot90(m, axes=(1, 0))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)
>>> jnp.rot90(m, k=-1, axes=(0, 1))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)

当输入数组的 ndim>2

>>> m1 = jnp.array([[[1, 2, 3],
...                  [4, 5, 6]],
...                 [[7, 8, 9],
...                  [10, 11, 12]]])
>>> jnp.rot90(m1, k=1, axes=(2, 1))
Array([[[ 4,  1],
        [ 5,  2],
        [ 6,  3]],

       [[10,  7],
        [11,  8],
        [12,  9]]], dtype=int32)