jax.numpy.isscalar#
- jax.numpy.isscalar(element)[源代码]#
如果输入是标量,则返回 True。
JAX 实现的
numpy.isscalar()
。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;有关更多详细信息,请参阅下面的注意。- 参数:
element (Any) – 要检查的输入对象;任何类型都是有效的输入。
- 返回:
如果
element
是标量值或零维类数组对象,则返回 True,否则返回 False。- 返回类型:
注意
JAX 和 NumPy 在标量值的表示方式上有所不同。NumPy 具有特殊的标量对象(例如
np.int32(0)
),它们与零维数组(例如np.array(0)
)不同,numpy.isscalar()
对前者返回True
,对后者返回False
。JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,
jax.numpy.isscalar()
对于标量对象(例如0.0
或np.float32(0.0)
)和零维类数组对象(例如jnp.array(0.0)
、np.array(0.0)
)都返回True
。isscalar
中采用不同约定的一个原因是保持 JIT 不变性:即当函数经过 JIT 编译后,其结果不应发生变化。由于标量输入在 JIT 边界会被转换为零维 JAX 数组,numpy.isscalar()
的语义会在 JIT 下改变结果。>>> np.isscalar(1.0) True >>> jax.jit(np.isscalar)(1.0) Array(False, dtype=bool)
通过将零维数组视为标量,
jax.numpy.isscalar()
避免了这个问题。>>> jnp.isscalar(1.0) True >>> jax.jit(jnp.isscalar)(1.0) Array(True, dtype=bool)
示例
在 JAX 中,标量和零维类数组对象都被视为标量。
>>> jnp.isscalar(1.0) True >>> jnp.isscalar(1 + 1j) True >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array True >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor True >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array True >>> jnp.isscalar(np.int32(1)) # NumPy scalar type True
具有一个或多个维度的数组不被视为标量。
>>> jnp.isscalar(jnp.array([1])) False >>> jnp.isscalar(np.array([1])) False
与
numpy.isscalar()
进行比较,它对标量类型的对象返回True
,对所有数组(即使是零维数组)返回False
。>>> np.isscalar(np.int32(1)) # scalar object True >>> np.isscalar(np.array(1)) # zero-dimensional array False
在 JAX 中,与 NumPy 一样,非类数组对象不被视为标量。
>>> jnp.isscalar(None) False >>> jnp.isscalar([1]) False >>> jnp.isscalar(tuple()) False >>> jnp.isscalar(slice(10)) False