jax.numpy.spacing#

jax.numpy.spacing(x, /)[源代码]#

返回 x 和下一个相邻数字之间的间距。

numpy.spacing() 的 JAX 实现。

参数:

x (ArrayLike) – 实值数组。整数或布尔类型将被转换为浮点数。

返回:

x 形状相同的数组,其中包含 x 中每个条目与其最接近的相邻值之间的间隔。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32')
>>> jnp.spacing(x)
Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08,
      1.1920929e-07], dtype=float32)

对于 x = 1, 间隔等于 jax.numpy.finfo 给出的 eps

>>> x = jnp.float32(1)
>>> jnp.spacing(x) == jnp.finfo(x.dtype).eps
Array(True, dtype=bool)