jax.numpy.nan_to_num#
- jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[源代码]#
替换数组中的 NaN 和无限条目。
numpy.nan_to_num()
的 JAX 实现。- 参数:
x (ArrayLike) – 要替换的值的数组。如果它没有非精确的 dtype,它将保持不变地返回。
copy (bool) – JAX 未使用
nan (ArrayLike) – 替换 NaN 条目的值。默认为 0.0。
posinf (ArrayLike | None) – 替换正无穷大条目的值。默认为最大可表示值。
neginf (ArrayLike | None) – 替换正无穷大条目的值。默认为最小可表示值。
- 返回:
带有请求替换的
x
的副本。- 返回类型:
另请参阅
jax.numpy.isnan()
:返回数组中包含 NaN 的位置为 Truejax.numpy.isposinf()
:返回数组中包含 +inf 的位置为 Truejax.numpy.isneginf()
:返回数组中包含 -inf 的位置为 True
示例
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
默认替换值
>>> jnp.nan_to_num(x) Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38, 2.0000000e+00, -3.4028235e+38], dtype=float32)
覆盖
-inf
和+inf
的替换>>> jnp.nan_to_num(x, posinf=999, neginf=-999) Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
如果您只想替换 NaN 值,而保留
inf
值不变,那么使用where()
与jax.numpy.isnan()
是更好的选择>>> jnp.where(jnp.isnan(x), 0, x) Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)