jax.numpy.promote_types#
- jax.numpy.promote_types(a, b)[源代码]#
返回二元运算应将其参数强制转换成的类型。
JAX 的
numpy.promote_types()
的实现。有关 JAX 的类型提升语义的详细信息,请参阅 类型提升语义。- 参数:
a (DTypeLike) – 一个
numpy.dtype
或一个 dtype 说明符。b (DTypeLike) – 一个
numpy.dtype
或一个 dtype 说明符。
- 返回:
一个
numpy.dtype
对象。- 返回类型:
DType
示例
类型说明符可以是字符串、dtypes 或标量类型,返回值始终是一个 dtype
>>> jnp.promote_types('int32', 'float32') # strings dtype('float32') >>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes dtype('float32') >>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types dtype('float32')
内置标量类型(
int
、float
或complex
)被视为弱类型,不会更改强类型对应项的位宽(请参阅类型提升语义中的讨论)。>>> jnp.promote_types('uint8', int) dtype('uint8') >>> jnp.promote_types('float16', float) dtype('float16')
这与 NumPy 版本的此函数不同,NumPy 将内置标量类型视为等效于 64 位类型。
>>> import numpy >>> numpy.promote_types('uint8', int) dtype('int64') >>> numpy.promote_types('float16', float) dtype('float64')