jax.lax.conv_dimension_numbers

jax.lax.conv_dimension_numbers#

jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)[source]#

将卷积 dimension_numbers 转换为 ConvDimensionNumbers

参数:
  • lhs_shape – 非负整数元组,卷积输入的形状。

  • rhs_shape – 非负整数元组,卷积核的形状。

  • dimension_numbers – None 或字符串元组/列表,或 ConvDimensionNumbers 对象,遵循 xla_client.py 中的卷积维度编号规范格式。

返回值:

一个 ConvDimensionNumbers 对象,以 lax 函数使用的规范形式表示 dimension_numbers

返回类型:

ConvDimensionNumbers