jax.numpy.isscalar#

jax.numpy.isscalar(element)[源代码]#

如果输入是标量,则返回 True。

JAX 实现的 numpy.isscalar()。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;有关更多详细信息,请参阅下面的注意

参数:

element (Any) – 要检查的输入对象;任何类型都是有效的输入。

返回:

如果 element 是标量值或零维类数组对象,则返回 True,否则返回 False。

返回类型:

bool

注意

JAX 和 NumPy 在标量值的表示方式上有所不同。NumPy 具有特殊的标量对象(例如 np.int32(0)),它们与零维数组(例如 np.array(0))不同,numpy.isscalar() 对前者返回 True,对后者返回 False

JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,jax.numpy.isscalar() 对于标量对象(例如 0.0np.float32(0.0))和零维类数组对象(例如 jnp.array(0.0)np.array(0.0))都返回 True

isscalar 中采用不同约定的一个原因是保持 JIT 不变性:即当函数经过 JIT 编译后,其结果不应发生变化。由于标量输入在 JIT 边界会被转换为零维 JAX 数组,numpy.isscalar() 的语义会在 JIT 下改变结果。

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

通过将零维数组视为标量,jax.numpy.isscalar() 避免了这个问题。

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

示例

在 JAX 中,标量和零维类数组对象都被视为标量。

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

具有一个或多个维度的数组不被视为标量。

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

numpy.isscalar() 进行比较,它对标量类型的对象返回 True,对所有数组(即使是零维数组)返回 False

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

在 JAX 中,与 NumPy 一样,非类数组对象不被视为标量。

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False