jax.numpy.issubdtype#
- jax.numpy.issubdtype(arg1, arg2)[源代码]#
如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。
JAX 实现的
numpy.issubdtype()
。JAX 实现的主要区别在于它正确处理了 dtype 扩展,例如
bfloat16
。- 参数:
arg1 (DTypeLike) – 类似 dtype 的对象。在典型用法中,这将是一个 dtype 说明符,例如
"float32"
(即一个字符串),np.dtype('int32')
(即numpy.dtype
的实例),jnp.complex64
(即 JAX 标量构造函数),或np.uint8
(即 NumPy 标量类型)。arg2 (DTypeLike) – 类似 dtype 的对象。在典型用法中,这将是一个通用标量类型,例如
jnp.integer
、jnp.floating
或jnp.complexfloating
。
- 返回:
如果 arg1 表示的 dtype 在类型层次结构中等于或低于 arg2,则为 True。
- 返回类型:
另请参阅
jax.numpy.isdtype()
:与数组 API 标准对齐的类似函数。
示例
>>> jnp.issubdtype('uint32', jnp.unsignedinteger) True >>> jnp.issubdtype(np.int32, jnp.integer) True >>> jnp.issubdtype(jnp.bfloat16, jnp.floating) True >>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating) True >>> jnp.issubdtype('complex64', jnp.integer) False
请注意,虽然这与
numpy.issubdtype()
非常相似,但在 JAX 自定义浮点类型的情况下,这些结果有所不同>>> np.issubdtype('bfloat16', np.floating) False >>> jnp.issubdtype('bfloat16', jnp.floating) True