jax.scipy.ndimage.map_coordinates

jax.scipy.ndimage.map_coordinates#

jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)[source]#

使用插值将输入数组映射到新的坐标。

JAX 实现 scipy.ndimage.map_coordinates()

给定一个输入数组和一组坐标,此函数返回输入数组在这些坐标处的插值值。

参数:
  • 输入 (Array | ndarray | bool | number | bool | int | float | complex) – 用于插值的 N 维输入数组。

  • 坐标 (Sequence[Array | ndarray | bool | number | bool | int | float | complex]) – 长度为 N 的序列,指定要评估插值值的坐标。

  • 阶数 (int) –

    插值的阶数。JAX 支持以下阶数

    • 0: 最近邻

    • 1: 线性

  • 模式 (str) – 输入边界外的点将根据给定的模式填充。JAX 支持以下模式之一:('constant', 'nearest', 'mirror', 'wrap', 'reflect')。请注意,JAX 中的 'wrap' 模式在 SciPy 中的行为与 'grid-wrap' 模式相同,而 JAX 中的 'constant' 模式在 SciPy 中的行为与 'grid-constant' 模式相同。这种差异是由 SciPy 中这些模式以前存在的错误 (scipy/scipy#2640) 引起的,JAX 首先通过改变现有模式的行为来修复了这个错误,后来 SciPy 通过添加新名称的模式来修复了这个错误,而不是修复现有的模式,这是为了向后兼容的原因。默认值为 'constant'。

  • cval (Array | ndarray | bool | number | bool | int | float | complex) – 如果 mode='constant',则用于输入边界外点的值。默认值为 0.0。

返回值:

指定坐标处插值的数值。

示例

>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
...                jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)

注意

边界附近的插值与 scipy 函数不同,因为 JAX 修复了一个未解决的错误;请参见 google/jax#11097。此函数将 mode 参数解释为 SciPy 文档中所述,但并非 SciPy 中的实现方式。