jax.numpy.resize#
- jax.numpy.resize(a, new_shape)[源代码]#
返回具有指定形状的新数组。
JAX 实现的
numpy.resize()
。- 参数:
a (ArrayLike) – 输入数组或标量。
new_shape (Shape) – int 或 int 元组。指定调整大小后的数组的形状。
- 返回:
具有指定形状的调整大小后的数组。如果调整大小后的数组大于原始数组,则调整大小后的数组中会重复
a
的元素。- 返回类型:
另请参阅
jax.numpy.reshape()
: 返回数组的重塑副本。jax.numpy.repeat()
: 从重复的元素构造数组。
示例
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)