jax.numpy.round_

内容

jax.numpy.round_#

jax.numpy.round_(a, decimals=0, out=None)[source]#

将输入均匀四舍五入到给定的十进制位数。

JAX 实现 numpy.round().

参数:
  • a (ArrayLike) – 输入数组或标量。

  • decimals (int) – int,默认为 0。输入需要四舍五入到的十进制位数。它必须是静态指定的。对于 decimals < 0 未实现。

  • out (None) – JAX 未使用。

返回值:

一个数组,包含舍入到指定 decimals 的值,其形状和数据类型与 a 相同。

返回类型:

数组

注意

jnp.round 将正好位于舍入十进制值之间的值舍入到最接近的偶数整数。

另请参阅

示例

>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)

对于正好位于舍入值之间的值

>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)