jax.numpy.insert#

jax.numpy.insert(arr, obj, values, axis=None)[源代码]#

在指定索引处将条目插入到数组中。

numpy.insert() 的 JAX 实现。

参数:
  • arr (ArrayLike) – 将在其中插入值的数组对象。

  • obj (ArrayLike | slice) – 指定插入位置的切片或索引数组。

  • values (ArrayLike) – 要插入的值的数组。

  • axis (int | None | None) – 指定多维数组情况下的插入轴。如果未指定,则 arr 将被展平。

返回:

一个 arr 的副本,其中值已插入到指定位置。

返回类型:

数组

另请参阅

示例

插入单个值

>>> x = jnp.arange(5)
>>> jnp.insert(x, 2, 99)
Array([ 0,  1, 99,  2,  3,  4], dtype=int32)

使用切片插入多个相同的值

>>> jnp.insert(x, slice(None, None, 2), -1)
Array([-1,  0,  1, -1,  2,  3, -1,  4], dtype=int32)

使用索引插入多个值

>>> indices = jnp.array([4, 2, 5])
>>> values = jnp.array([10, 11, 12])
>>> jnp.insert(x, indices, values)
Array([ 0,  1, 11,  2,  3, 10,  4, 12], dtype=int32)

将列插入到二维数组中

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> indices = jnp.array([1, 3])
>>> values = jnp.array([[10, 11],
...                     [12, 13]])
>>> jnp.insert(x, indices, values, axis=1)
Array([[ 1, 10,  2,  3, 11],
       [ 4, 12,  5,  6, 13]], dtype=int32)