jax.numpy.issubdtype#

jax.numpy.issubdtype(arg1, arg2)[源代码]#

如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。

JAX 实现的 numpy.issubdtype()

JAX 实现的主要区别在于它正确处理了 dtype 扩展,例如 bfloat16

参数:
  • arg1 (DTypeLike) – 类似 dtype 的对象。在典型用法中,这将是一个 dtype 说明符,例如 "float32"(即一个字符串),np.dtype('int32')(即 numpy.dtype 的实例),jnp.complex64(即 JAX 标量构造函数),或 np.uint8(即 NumPy 标量类型)。

  • arg2 (DTypeLike) – 类似 dtype 的对象。在典型用法中,这将是一个通用标量类型,例如 jnp.integerjnp.floatingjnp.complexfloating

返回:

如果 arg1 表示的 dtype 在类型层次结构中等于或低于 arg2,则为 True。

返回类型:

bool

另请参阅

示例

>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False

请注意,虽然这与 numpy.issubdtype() 非常相似,但在 JAX 自定义浮点类型的情况下,这些结果有所不同

>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True