jax.lax.full#

jax.lax.full(shape, fill_value, dtype=None, *, sharding=None)[源代码]#

返回一个用 fill_value 填充的 shape 数组。

参数:
  • shape (形状) – 整数序列,描述输出数组的形状。

  • fill_value (类数组) – 用于填充新数组的值。

  • dtype (数据类型 | None | None) – 输出数组的类型,或者 None。如果不是 Nonefill_value 将被强制转换为 dtype

  • sharding (分片 | None | None) – 结果数组的可选分片规范,注意,分片目前在 JIT 模式下将被忽略,未来可能会发生改变。

返回类型:

数组