jax.numpy.arange#
- jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)[源代码]#
创建一个等间隔数值的数组。
JAX 实现的
numpy.arange()
,基于jax.lax.iota()
实现。类似于 Python 的
range()
函数,可以使用以下几种不同的位置签名进行调用jnp.arange(stop)
: 生成从 0 到stop
的值,步长为 1。jnp.arange(start, stop)
: 生成从start
到stop
的值,步长为 1。jnp.arange(start, stop, step)
: 生成从start
到stop
的值,步长为step
。
与 Python 的
range()
函数一样,起始值是包含的,而结束值是不包含的。- 参数:
start (ArrayLike | DimSize) – 区间的起始值,包含在内。
stop (ArrayLike | DimSize | None | None) – 区间的可选结束值,不包含在内。如果未指定,则
(start, stop) = (0, start)
step (ArrayLike | None | None) – 区间的可选步长。默认值 = 1。
dtype (DTypeLike | None | None) – 返回数组的可选数据类型;如果未指定,则将通过 start、stop 和 step 的类型提升来确定。
device (xc.Device | Sharding | None | None) – (可选)
Device
或Sharding
,创建的数组将被提交到该设备或分片。
- 返回:
从
start
到stop
的等间隔数值数组,间隔为step
。- 返回类型:
注意
使用浮点数
step
参数的arange
可能会由于浮点数误差的累积而导致意外的结果,尤其是对于像float8_*
和bfloat16
这样的低精度数据类型。为了避免精度误差,请考虑生成一个整数范围,并将其缩放到所需的范围。例如,与其这样jnp.arange(-1, 1, 0.01, dtype='bfloat16')
不如生成一个整数序列,并对其进行缩放会更准确
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
示例
单参数版本仅指定
stop
值>>> jnp.arange(4) Array([0, 1, 2, 3], dtype=int32)
传递浮点数
stop
值会导致浮点数结果>>> jnp.arange(4.0) Array([0., 1., 2., 3.], dtype=float32)
双参数版本指定
start
和stop
,step=1
>>> jnp.arange(1, 6) Array([1, 2, 3, 4, 5], dtype=int32)
三参数版本指定
start
、stop
和step
>>> jnp.arange(0, 2, 0.5) Array([0. , 0.5, 1. , 1.5], dtype=float32)
另请参阅
jax.numpy.linspace()
: 生成固定数量的等间隔数值。jax.lax.iota()
: 直接在 XLA 中生成整数序列。