jax.numpy.nan_to_num#

jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[源代码]#

替换数组中的 NaN 和无限条目。

numpy.nan_to_num() 的 JAX 实现。

参数
  • x (ArrayLike) – 要替换的值的数组。如果它没有非精确的 dtype,它将保持不变地返回。

  • copy (bool) – JAX 未使用

  • nan (ArrayLike) – 替换 NaN 条目的值。默认为 0.0。

  • posinf (ArrayLike | None) – 替换正无穷大条目的值。默认为最大可表示值。

  • neginf (ArrayLike | None) – 替换正无穷大条目的值。默认为最小可表示值。

返回

带有请求替换的 x 的副本。

返回类型

数组

另请参阅

示例

>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])

默认替换值

>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00,  3.4028235e+38,
        2.0000000e+00, -3.4028235e+38], dtype=float32)

覆盖 -inf+inf 的替换

>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([   0.,    0.,    1.,  999.,    2., -999.], dtype=float32)

如果您只想替换 NaN 值,而保留 inf 值不变,那么使用 where()jax.numpy.isnan() 是更好的选择

>>> jnp.where(jnp.isnan(x), 0, x)
Array([  0.,   0.,   1.,  inf,   2., -inf], dtype=float32)