jax.lax.broadcast_to_rank

jax.lax.broadcast_to_rank#

jax.lax.broadcast_to_rank(x, rank)[source]#

x 的前面添加 1 的维度,使其具有 rank 个维度。

参数:
  • x (ArrayLike)

  • rank (int)

返回值类型:

Array