jax.numpy.nan_to_num#

jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[源代码]#

替换数组中的 NaN 和无限值。

JAX 实现的 numpy.nan_to_num()

参数:
  • x (ArrayLike) – 要替换的值的数组。如果它没有非精确的数据类型,则会原样返回。

  • copy (bool) – JAX 未使用

  • nan (ArrayLike) – 替换 NaN 条目的值。默认为 0.0。

  • posinf (ArrayLike | None) – 替换正无穷条目的值。默认为最大可表示值。

  • neginf (ArrayLike | None) – 替换负无穷条目的值。默认为最小可表示值。

返回:

一个带有请求替换的 x 的副本。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])

默认替换值

>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00,  3.4028235e+38,
        2.0000000e+00, -3.4028235e+38], dtype=float32)

覆盖 -inf+inf 的替换值

>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([   0.,    0.,    1.,  999.,    2., -999.], dtype=float32)

如果您只想替换 NaN 值,而保持 inf 值不变,则使用 where()jax.numpy.isnan() 是更好的选择

>>> jnp.where(jnp.isnan(x), 0, x)
Array([  0.,   0.,   1.,  inf,   2., -inf], dtype=float32)