jax.numpy.insert#
- jax.numpy.insert(arr, obj, values, axis=None)[源代码]#
在指定索引处将条目插入到数组中。
numpy.insert()
的 JAX 实现。- 参数:
- 返回:
一个
arr
的副本,其中值已插入到指定位置。- 返回类型:
另请参阅
jax.numpy.delete()
: 从数组中删除条目。
示例
插入单个值
>>> x = jnp.arange(5) >>> jnp.insert(x, 2, 99) Array([ 0, 1, 99, 2, 3, 4], dtype=int32)
使用切片插入多个相同的值
>>> jnp.insert(x, slice(None, None, 2), -1) Array([-1, 0, 1, -1, 2, 3, -1, 4], dtype=int32)
使用索引插入多个值
>>> indices = jnp.array([4, 2, 5]) >>> values = jnp.array([10, 11, 12]) >>> jnp.insert(x, indices, values) Array([ 0, 1, 11, 2, 3, 10, 4, 12], dtype=int32)
将列插入到二维数组中
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> indices = jnp.array([1, 3]) >>> values = jnp.array([[10, 11], ... [12, 13]]) >>> jnp.insert(x, indices, values, axis=1) Array([[ 1, 10, 2, 3, 11], [ 4, 12, 5, 6, 13]], dtype=int32)