jax.numpy.rint

内容

jax.numpy.rint#

jax.numpy.rint(x, /)[source]#

将 x 的元素四舍五入到最接近的整数

JAX 实现 numpy.rint.

参数:

x (ArrayLike) – 输入数组

返回值:

一个包含 x 的四舍五入元素的类似数组的对象。始终提升为不精确。

返回值类型:

Array

注意

如果 x 的元素正好是中间值,例如 0.51.5,rint 将四舍五入到最接近的偶数整数。

示例

>>> x1 = jnp.array([5, 4, 7])
>>> jnp.rint(x1)
Array([5., 4., 7.], dtype=float32)
>>> x2 = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
>>> jnp.rint(x2)
Array([-2., -2., -0.,  0.,  2.,  2.,  4.,  4.], dtype=float32)
>>> x3 = jnp.array([-2.5+3.5j, 4.5-0.5j])
>>> jnp.rint(x3)
Array([-2.+4.j,  4.-0.j], dtype=complex64)