jax.numpy.apply_along_axis#
- jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[源代码]#
沿轴将函数应用于 1D 数组切片。
JAX 实现的
numpy.apply_along_axis()
。NumPy 迭代地实现此功能,而 JAX 通过jax.vmap()
实现此功能,因此func1d
必须与vmap
兼容。- 参数:
func1d (Callable) – 一个可调用函数,签名如下
func1d(arr, /, *args, **kwargs)
,其中*args
和**kwargs
是传递给apply_along_axis()
的附加位置和关键字参数。axis (int) – 应用函数的整数轴。
arr (ArrayLike) – 应用函数的数组。
args – 传递给
func1d
的附加位置和关键字参数。kwargs – 传递给
func1d
的附加位置和关键字参数。
- 返回:
沿指定轴应用
func1d
的结果。- 返回类型:
另请参阅
jax.vmap()
: 一种创建函数向量化版本的更直接方法。jax.numpy.apply_over_axes()
: 在多个轴上重复应用函数。jax.numpy.vectorize()
: 创建函数的向量化版本。
示例
一个二维的简单示例,其中函数按行或按列应用
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> def func1d(x): ... return jnp.sum(x ** 2) >>> jnp.apply_along_axis(func1d, 0, x) Array([17, 29, 45], dtype=int32) >>> jnp.apply_along_axis(func1d, 1, x) Array([14, 77], dtype=int32)
对于 2D 输入,可以使用
jax.vmap()
等效地表达,但请注意,vmap 指定的是映射轴而不是应用轴>>> jax.vmap(func1d, in_axes=1)(x) # same as applying along axis 0 Array([17, 29, 45], dtype=int32) >>> jax.vmap(func1d, in_axes=0)(x) # same as applying along axis 1 Array([14, 77], dtype=int32)
对于 3D 输入,
apply_along_axis()
等价于在两个维度上映射>>> x_3d = jnp.arange(24).reshape(2, 3, 4) >>> jnp.apply_along_axis(func1d, 2, x_3d) Array([[ 14, 126, 366], [ 734, 1230, 1854]], dtype=int32) >>> jax.vmap(jax.vmap(func1d))(x_3d) Array([[ 14, 126, 366], [ 734, 1230, 1854]], dtype=int32)
应用的函数还可以接受任意位置或关键字参数,这些参数应直接作为额外的参数传递给
apply_along_axis()
>>> def func1d(x, exponent): ... return jnp.sum(x ** exponent) >>> jnp.apply_along_axis(func1d, 0, x, exponent=3) Array([ 65, 133, 243], dtype=int32)