jax.numpy.place#
- jax.numpy.place(arr, mask, vals, *, inplace=True)[source]#
根据掩码更新数组元素。
JAX 实现
numpy.place()
.The semantics of
numpy.place()
are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds theinplace
parameter which must be set to False` by the user as a reminder of this API difference.- 参数:
arr (ArrayLike) – 要放置值的数组。
mask (ArrayLike) – 与
arr
大小相同的布尔掩码。vals (ArrayLike) – 要插入
arr
中的值,位置由掩码指示。如果提供的数值过多,将被截断。如果提供的数值不足,将被重复。inplace (bool) – 必须设置为 False,表示输入不会就地修改,而是返回一个修改后的副本。
- 返回值:
arr
的副本,其中掩码值设置为来自 vals 的条目。- 返回类型:
另请参阅
jax.numpy.put()
: 将元素放入数组中的数字索引处。jax.numpy.ndarray.at()
: 使用 NumPy 风格的索引进行数组更新
示例
>>> x = jnp.zeros((3, 5), dtype=int) >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape) >>> mask Array([[ True, False, False, True, False], [False, True, False, False, True], [False, False, True, False, False]], dtype=bool)
放置标量值
>>> jnp.place(x, mask, 1, inplace=False) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
在这种情况下,
jnp.place
类似于掩码数组更新语法>>> x.at[mask].set(1) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
place
在放置来自数组的值时有所不同。该数组将被重复以填充掩码条目>>> vals = jnp.array([1, 3, 5]) >>> jnp.place(x, mask, vals, inplace=False) Array([[1, 0, 0, 3, 0], [0, 5, 0, 0, 1], [0, 0, 3, 0, 0]], dtype=int32)