jax.lax.conv_dimension_numbers#
- jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)[source]#
将卷积 dimension_numbers 转换为 ConvDimensionNumbers。
- 参数:
lhs_shape – 非负整数元组,卷积输入的形状。
rhs_shape – 非负整数元组,卷积核的形状。
dimension_numbers – None 或字符串元组/列表,或 ConvDimensionNumbers 对象,遵循 xla_client.py 中的卷积维度编号规范格式。
- 返回值:
一个 ConvDimensionNumbers 对象,以 lax 函数使用的规范形式表示 dimension_numbers。
- 返回类型: