jax.numpy.linspace

内容

jax.numpy.linspace#

jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: Literal[False] = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) Array[source]#
jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: bool, retstep: Literal[True], dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) tuple[Array, Array]
jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, *, retstep: Literal[True], dtype: DTypeLike | None = None, axis: int = 0, device: xc.Device | Sharding | None = None) tuple[Array, Array]
jax.numpy.linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) Array | tuple[Array, Array]

在区间内返回均匀间隔的数字。

numpy.linspace() 的 JAX 实现。

参数:
  • start – 标量或起始值的数组。

  • stop – 标量或停止值的数组。

  • num – 要生成的数值个数。默认值:50。

  • endpoint – 如果为 True(默认值),则在结果中包含 stop 值。如果为 False,则排除 stop 值。

  • retstep – 如果为 True,则返回一个 (result, step) 元组,其中 stepresult 中相邻值之间的间隔。

  • axis – 生成 linspace 的整数轴。默认为零。

  • device – 可选的 DeviceSharding,创建的数组将被提交到该设备或分片。

返回值:

  • values 是一个从 startstop 的均匀间隔值的数组

  • step 是相邻值之间的间隔。

返回类型:

一个数组 values,或者如果 retstep 为 True,则为一个元组 (values, step),其中

另请参阅

示例

0 到 10 之间 5 个值的列表

>>> jnp.linspace(0, 10, 5)
Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)

0 到 10 之间 8 个值的列表,不包括端点

>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0.  , 1.25, 2.5 , 3.75, 5.  , 6.25, 7.5 , 8.75], dtype=float32)

值列表及其之间的步长

>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0.  ,  1.25,  2.5 ,  3.75,  5.  ,  6.25,  7.5 ,  8.75, 10.  ],      dtype=float32)
>>> step
Array(1.25, dtype=float32)

多维 linspace

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0.  ,  5.  ],
       [ 1.25,  6.25],
       [ 2.5 ,  7.5 ],
       [ 3.75,  8.75],
       [ 5.  , 10.  ]], dtype=float32)