jax.lax.broadcast_in_dim

jax.lax.broadcast_in_dim#

jax.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)[source]#

封装 XLA 的 BroadcastInDim 运算符。

参数:
  • operand (ArrayLike) – 数组

  • shape (Shape) – 目标数组的形状

  • broadcast_dimensions (Sequence[int]) – 操作数形状的每个维度对应于目标形状中的哪个维度。也就是说,操作数的维度 i 成为结果的维度 broadcast_dimensions[i]。

返回:

包含结果的数组。

返回类型:

Array

另请参阅

jax.lax.broadcast:添加新前导维度的更简单的接口。