jax.numpy.delete#
- jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)[source]#
从数组中删除条目。
JAX 实现的
numpy.delete()
。- 参数:
- 返回:
删除了指定索引的
arr
的副本。- 返回类型:
注意
delete()
通常需要索引规范是静态的。如果索引是保证包含唯一条目的整数数组,则可以指定assume_unique_indices=True
以便以不需要静态索引的方式执行操作。另请参阅
jax.numpy.insert()
:将条目插入数组。
示例
从一维数组中删除条目
>>> a = jnp.array([4, 5, 6, 7, 8, 9]) >>> jnp.delete(a, 2) Array([4, 5, 7, 8, 9], dtype=int32) >>> jnp.delete(a, slice(1, 4)) # delete a[1:4] Array([4, 8, 9], dtype=int32) >>> jnp.delete(a, slice(None, None, 2)) # delete a[::2] Array([5, 7, 9], dtype=int32)
沿指定轴从二维数组中删除条目
>>> a2 = jnp.array([[4, 5, 6], ... [7, 8, 9]]) >>> jnp.delete(a2, 1, axis=1) Array([[4, 6], [7, 9]], dtype=int32)
通过索引序列删除多个条目
>>> indices = jnp.array([0, 1, 3]) >>> jnp.delete(a, indices) Array([6, 8, 9], dtype=int32)
由于输出形状在可能存在重复索引的情况下是未知的,因此这在
jit()
和其他变换下会失败>>> jax.jit(jnp.delete)(a, indices) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
如果可以确保索引是唯一的,则传递
assume_unique_indices
以允许在 JIT 下执行此操作>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices']) >>> jit_delete(a, indices, assume_unique_indices=True) Array([6, 8, 9], dtype=int32)