jax.numpy.apply_along_axis#

jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[源代码]#

沿轴将函数应用于 1D 数组切片。

JAX 实现的 numpy.apply_along_axis()。NumPy 迭代地实现此功能,而 JAX 通过 jax.vmap() 实现此功能,因此 func1d 必须与 vmap 兼容。

参数:
  • func1d (Callable) – 一个可调用函数,签名如下 func1d(arr, /, *args, **kwargs),其中 *args**kwargs 是传递给 apply_along_axis() 的附加位置和关键字参数。

  • axis (int) – 应用函数的整数轴。

  • arr (ArrayLike) – 应用函数的数组。

  • args – 传递给 func1d 的附加位置和关键字参数。

  • kwargs – 传递给 func1d 的附加位置和关键字参数。

返回:

沿指定轴应用 func1d 的结果。

返回类型:

Array

另请参阅

示例

一个二维的简单示例,其中函数按行或按列应用

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...   return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)

对于 2D 输入,可以使用 jax.vmap() 等效地表达,但请注意,vmap 指定的是映射轴而不是应用轴

>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)

对于 3D 输入,apply_along_axis() 等价于在两个维度上映射

>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)

应用的函数还可以接受任意位置或关键字参数,这些参数应直接作为额外的参数传递给 apply_along_axis()

>>> def func1d(x, exponent):
...   return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)