jax.numpy.apply_over_axes#
- jax.numpy.apply_over_axes(func, a, axes)[source]#
在多个轴上重复应用函数。
LAX 后端实现
numpy.apply_over_axes()
.原始文档字符串如下。
func 被调用为 res = func(a, axis),其中 axis 是 axes 的第一个元素。函数调用的结果 res 必须与 a 具有相同的维度或少一个维度。如果 res 比 a 少一个维度,则在 axis 之前插入一个维度。然后对 axes 中的每个轴重复调用 func,其中 res 作为第一个参数。
- 参数:
func (function) – 此函数必须接受两个参数,func(a, axis)。
a (array_like) – 输入数组。
axes (array_like) – 应用 func 的轴;元素必须是整数。
- 返回:
apply_over_axis – 输出数组。维度数量与a相同,但形状可能不同。这取决于func是否改变其输出相对于输入的形状。
- 返回类型:
ndarray