jax.lax.dynamic_update_index_in_dim#
- jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)[source]#
围绕
dynamic_update_slice()
的便捷包装器,用于在单个axis
中更新大小为 1 的切片。- 参数:
- 返回值:
更新后的数组
- 返回类型:
示例
>>> x = jnp.zeros(6) >>> y = 1.0 >>> dynamic_update_index_in_dim(x, y, 2, axis=0) Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0]) >>> dynamic_update_index_in_dim(x, y, 2, axis=0) Array([0., 0., 1., 0., 0., 0.], dtype=float32)
如果指定的索引超出范围,则索引将被剪裁到有效范围
>>> dynamic_update_index_in_dim(x, y, 10, axis=0) Array([0., 0., 0., 0., 0., 1.], dtype=float32)
这是一个二维动态索引更新的示例
>>> x = jnp.zeros((4, 4)) >>> y = jnp.ones(4) >>> dynamic_update_index_in_dim(x, y, 1, axis=0) Array([[0., 0., 0., 0.], [1., 1., 1., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.]], dtype=float32)
请注意,
update
中其他轴的形状不必与operand
的关联维度匹配>>> y = jnp.ones((1, 3)) >>> dynamic_update_index_in_dim(x, y, 1, 0) Array([[0., 0., 0., 0.], [1., 1., 1., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], dtype=float32)