jax.numpy.transpose#
- jax.numpy.transpose(a, axes=None)[源代码]#
返回 N 维数组的转置版本。
JAX 实现的
numpy.transpose()
,根据jax.lax.transpose()
实现。- 参数:
a (ArrayLike) – 输入数组
axes (Sequence[int] | None | None) – 可选参数,用于指定排列顺序。它是一个长度为 a.ndim 的整数序列
i
,满足0 <= i < a.ndim
。默认为range(a.ndim)[::-1]
,即反转所有轴的顺序。
- 返回:
数组的转置副本。
- 返回类型:
另请参阅
jax.Array.transpose()
: 通过Array
方法实现的等效函数。jax.Array.T
: 通过Array
属性实现的等效函数。jax.numpy.matrix_transpose()
: 转置数组的最后两个轴。适用于处理批量 2D 矩阵。jax.numpy.swapaxes()
: 交换数组中的任意两个轴。jax.numpy.moveaxis()
: 将一个轴移动到数组中的另一个位置。
注意
与
numpy.transpose()
不同,jax.numpy.transpose()
将返回输入数组的副本,而不是视图。但是,在 JIT 编译下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。示例
对于一维数组,转置是恒等变换
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
对于二维数组,转置是矩阵转置
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
对于 N 维数组,转置会反转轴的顺序
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
可以指定
axes
参数来更改此默认行为>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
由于交换最后两个轴是常见操作,因此可以通过其自己的 API
jax.numpy.matrix_transpose()
完成。>>> jnp.matrix_transpose(x).shape (3, 5, 4)
为方便起见,也可以使用
jax.Array.transpose()
方法或jax.Array.T
属性执行转置。>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)