jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[source]#
返回给定轴上非零元素的数量。
JAX 实现
numpy.count_nonzero()
.- 参数:
**a** (ArrayLike) – 输入数组。
**axis** (Axis) – 可选,int 或 int 序列,默认值=None。计算非零数量的轴。如果为 None,则在扁平化数组中计算。
**keepdims** (bool) – 布尔值,默认值=False。如果为 True,则保留结果中的缩减轴,其大小为 1。
- 返回值:
一个数组,其中包含输入沿指定轴的非零元素数量。
- 返回类型:
示例
默认情况下,
jnp.count_nonzero
计算所有轴上的非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1
,则沿轴 1 计算。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
为了保留输入的维度,可以设置
keepdims=True
。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)