jax.numpy.promote_types#

jax.numpy.promote_types(a, b)[源代码]#

返回二元运算应将其参数强制转换成的类型。

JAX 的 numpy.promote_types() 的实现。有关 JAX 的类型提升语义的详细信息,请参阅 类型提升语义

参数
  • a (DTypeLike) – 一个 numpy.dtype 或一个 dtype 说明符。

  • b (DTypeLike) – 一个 numpy.dtype 或一个 dtype 说明符。

返回

一个 numpy.dtype 对象。

返回类型

DType

示例

类型说明符可以是字符串、dtypes 或标量类型,返回值始终是一个 dtype

>>> jnp.promote_types('int32', 'float32')  # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32'))  # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32)  # scalar types
dtype('float32')

内置标量类型(intfloatcomplex)被视为弱类型,不会更改强类型对应项的位宽(请参阅类型提升语义中的讨论)。

>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')

这与 NumPy 版本的此函数不同,NumPy 将内置标量类型视为等效于 64 位类型。

>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')