jax.numpy.delete#

jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)[源代码]#

从数组中删除一个或多个条目。

numpy.delete() 的 JAX 实现。

参数
  • arr (ArrayLike) – 要从中删除条目的数组。

  • obj (ArrayLike | slice) – 要删除的索引、多个索引或切片。

  • axis (int | None | None) – 要沿其删除条目的轴。

  • assume_unique_indices (bool) – 在使用类似数组的整数(非布尔值)索引的情况下,假设索引是唯一的,并以与 JIT 和其他 JAX 转换兼容的方式执行删除。

返回

删除了指定索引的 arr 的副本。

返回类型

Array

注意

delete() 通常要求索引规范是静态的。如果索引是一个保证包含唯一条目的整数数组,则可以指定 assume_unique_indices=True 以一种不需要静态索引的方式执行操作。

另请参阅

示例

从一维数组中删除条目

>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)

沿指定轴从二维数组中删除条目

>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)

通过索引序列删除多个条目

>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)

jit() 和其他转换下,这将失败,因为在存在重复索引的情况下,无法知道输出形状

>>> jax.jit(jnp.delete)(a, indices)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].

如果您可以确保索引是唯一的,请传递 assume_unique_indices 以便在 JIT 下执行此操作

>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)