jax.numpy.apply_along_axis#
- jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[source]#
沿给定轴将函数应用于一维切片。
numpy.apply_along_axis()
的 LAX 后端实现。原始文档字符串如下。
执行 func1d(a, *args, **kwargs),其中 func1d 对一维数组进行操作,而 a 是 arr 沿 axis 的一维切片。
这等效于(但比)以下使用 ndindex 和 s_ 的方法,它将
ii
、jj
和kk
分别设置为索引元组Ni, Nk = a.shape[:axis], a.shape[axis+1:] for ii in ndindex(Ni): for kk in ndindex(Nk): f = func1d(arr[ii + s_[:,] + kk]) Nj = f.shape for jj in ndindex(Nj): out[ii + jj + kk] = f[jj]
等效地,消除内部循环,这可以表示为
Ni, Nk = a.shape[:axis], a.shape[axis+1:] for ii in ndindex(Ni): for kk in ndindex(Nk): out[ii + s_[...,] + kk] = func1d(arr[ii + s_[:,] + kk])
- 参数:
func1d (function (M,) -> (Nj...)) – 此函数应接受一维数组。它应用于 arr 沿指定轴的一维切片。
axis (integer) – arr 切片的轴。
arr (ndarray (Ni..., M, Nk...)) – 输入数组。
args (any) – func1d 的附加参数。
kwargs (any) – func1d 的附加命名参数。
- 返回值:
out – 输出数组。 out 的形状与 arr 的形状相同,除了 axis 维度。此轴将被移除,并替换为等于 func1d 的返回值形状的新维度。因此,如果 func1d 返回标量,则 out 的维度将比 arr 少一个。
- 返回类型:
ndarray (Ni…, Nj…, Nk…)