jax.numpy.place#

jax.numpy.place(arr, mask, vals, *, inplace=True)[源代码]#

根据掩码更新数组元素。

JAX 实现的 numpy.place()

numpy.place() 的语义是原地修改数组,这对于 JAX 的不可变数组来说是不可能的。 JAX 版本返回输入修改后的副本,并添加了 inplace 参数,用户必须将其设置为 False,以提醒此 API 的差异。

参数:
  • arr (ArrayLike) – 将值放入的数组。

  • mask (ArrayLike) – 与 arr 大小相同的布尔掩码。

  • vals (ArrayLike) – 将插入到 arr 中由掩码指示位置的值。 如果提供的值过多,它们将被截断。 如果提供的值不足,它们将被重复。

  • inplace (bool) – 必须设置为 False 以指示输入不是原地修改的,而是返回修改后的副本。

返回:

arr 的副本,其中掩码的值设置为来自 vals 的条目。

返回类型:

数组

另请参阅

示例

>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)

放置标量值

>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

在这种情况下,jnp.place 类似于掩码数组更新语法

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

从数组放置值时,place 有所不同。 数组被重复以填充掩码的条目

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)