jax.dtypes 模块#

bfloat16

bfloat16 浮点数值

canonicalize_dtype(dtype[, allow_extended_dtype])

根据 config.x64_enabled 将 dtype 转换为规范 dtype。

float0

与同名的标量类型和 dtype 对应的 DType 类。

issubdtype(a, b)

如果第一个参数是类型层次结构中较低/等于的类型代码,则返回 True。

prng_key()

PRNG Key dtype 的标量类。

result_type(*args[, return_weak_type_flag])

应用 JAX 参数 dtype 提升的便捷函数。

scalar_type_of(x)

返回与 JAX 值关联的标量类型。