jax.numpy.trunc#
- jax.numpy.trunc(x)[源代码]#
将输入舍入到最接近零的整数。
numpy.trunc()
的 JAX 实现。- 参数:
x (类数组) – 输入数组或标量。
- 返回:
一个与
x
形状和 dtype 相同的数组,其中包含舍入后的值。- 返回类型:
另请参阅
jax.numpy.fix()
: 将输入舍入到最接近零的整数。jax.numpy.ceil()
: 将输入向上舍入到最接近的整数。jax.numpy.floor()
: 将输入向下舍入到最接近的整数。
示例
>>> key = jax.random.key(42) >>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10) >>> with jnp.printoptions(precision=2, suppress=True): ... print(x) [[-0.23 3.6 2.33] [ 1.22 -0.99 1.72] [-8.5 5.5 3.98]] >>> jnp.trunc(x) Array([[-0., 3., 2.], [ 1., -0., 1.], [-8., 5., 3.]], dtype=float32)