jax.numpy.result_type#

jax.numpy.result_type(*args)[源代码]#

返回将 JAX 提升规则应用于输入的结果。

JAX 实现的 numpy.result_type()

JAX 的数据类型提升行为在 类型提升语义 中描述。

参数:

args (Any) – 一个或多个数组或类似 dtype 的对象。

返回:

一个 numpy.dtype 实例,表示输入的类型提升结果。

返回类型:

DType

示例

输入可以是 dtype 说明符

>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')

输入也可以是标量或数组

>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')

请注意,结果类型将基于 jax_enable_x64 配置标志的状态进行规范化,这意味着 64 位类型可能会被向下转换为 32 位

>>> jnp.result_type('float64')
dtype('float32')

有关 64 位值的详细信息,请参阅 Sharp bits - 双精度