jax.numpy.result_type#
- jax.numpy.result_type(*args)[源代码]#
返回将 JAX 提升规则应用于输入的结果。
JAX 实现的
numpy.result_type()
。JAX 的数据类型提升行为在 类型提升语义 中描述。
- 参数:
args (Any) – 一个或多个数组或类似 dtype 的对象。
- 返回:
一个
numpy.dtype
实例,表示输入的类型提升结果。- 返回类型:
DType
示例
输入可以是 dtype 说明符
>>> jnp.result_type('int32', 'float32') dtype('float32') >>> jnp.result_type(np.uint16, np.dtype('int32')) dtype('int32')
输入也可以是标量或数组
>>> jnp.result_type(1.0, jnp.bfloat16(2)) dtype(bfloat16) >>> jnp.result_type(jnp.arange(4), jnp.zeros(4)) dtype('float32')
请注意,结果类型将基于
jax_enable_x64
配置标志的状态进行规范化,这意味着 64 位类型可能会被向下转换为 32 位>>> jnp.result_type('float64') dtype('float32')
有关 64 位值的详细信息,请参阅 Sharp bits - 双精度