jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[源代码]#
返回给定轴上非零元素的数量。
numpy.count_nonzero()
的 JAX 实现。- 参数:
a (ArrayLike) – 输入数组。
axis (Axis) – 可选,整数或整数序列,默认=None。 计算非零数的轴。 如果为 None,则在展平的数组中计数。
keepdims (bool) – 布尔值,默认=False。 如果为 true,则缩减的轴将保留在结果中,大小为 1。
- 返回:
一个数组,其中包含输入指定轴上的非零元素数量。
- 返回类型:
示例
默认情况下,
jnp.count_nonzero
会计算所有轴上的非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1
,则沿轴 1 计算。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
要保留输入的维度,您可以设置
keepdims=True
。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)