jax.linear_transpose#
- jax.linear_transpose(fun, *primals, reduce_axes=())[源代码]#
转置一个承诺是线性的函数。
对于线性函数,此转换等效于
vjp()
,但避免了计算前向传递的开销。转置函数的输出将始终具有与
primals
完全相同的数据类型,即使某些值被截断(例如,从复数到浮点数,或从 float64 到 float32)。 要避免截断,请在primals
中使用与转置函数所需输出的完整范围匹配的数据类型。 不支持整数数据类型。- 参数:
fun (Callable) – 要转置的线性函数。
*primals – 位置参数元组,包含数组、标量或这些类型的(嵌套)标准Python容器(元组、列表、字典、具名元组,即pytrees),用于评估
fun(*primals)
的形状/数据类型。这些参数可以是实数标量/ndarrays,但这不是必需的:只访问shape
和dtype
属性。请参阅下面的示例。(请注意,鸭子类型对象不能是具名元组,因为它们被视为标准 Python 容器。)
- 返回:
一个可调用对象,计算
fun
的转置。此函数的有效输入必须具有与fun(*primals)
的结果相同的形状/数据类型/结构。输出将是一个元组,其形状/数据类型/结构与primals
相同。- 返回类型:
可调用对象
>>> import jax >>> import types >>> >>> f = lambda x, y: 0.5 * x - 0.5 * y >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> f_transpose = jax.linear_transpose(f, scalar, scalar) >>> f_transpose(1.0) (Array(0.5, dtype=float32), Array(-0.5, dtype=float32))