jax.numpy.transpose#
- jax.numpy.transpose(a, axes=None)[source]#
返回 N 维数组的转置版本。
JAX 实现
numpy.transpose()
,使用jax.lax.transpose()
实现。- 参数:
a (ArrayLike) – 输入数组
axes (Sequence[int] | None | None) – 可选地使用长度为 a.ndim 的整数序列
i
指定排列,满足0 <= i < a.ndim
。默认值为range(a.ndim)[::-1]
,即反转所有轴的顺序。
- 返回值:
数组的转置副本。
- 返回类型:
参见
jax.Array.transpose()
: 通过Array
方法实现的等效函数。jax.Array.T
: 通过Array
属性实现的等效函数。jax.numpy.matrix_transpose()
: 对数组的最后两个轴进行转置。这适用于处理批量二维矩阵。jax.numpy.swapaxes()
: 交换数组中的任意两个轴。jax.numpy.moveaxis()
: 将一个轴移动到数组中的另一个位置。
注意
与
numpy.transpose()
不同,jax.numpy.transpose()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器会尽可能优化掉这些副本,因此在实践中不会影响性能。示例
对于一维数组,转置是恒等变换
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
对于二维数组,转置是矩阵转置
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
对于 N 维数组,转置会反转轴的顺序
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
可以指定
axes
参数来改变此默认行为>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
由于交换最后两个轴是一个常见的操作,可以通过其自己的 API
jax.numpy.matrix_transpose()
来实现。>>> jnp.matrix_transpose(x).shape (3, 5, 4)
为了方便起见,也可以使用
jax.Array.transpose()
方法或jax.Array.T
属性来执行转置。>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)