jax.numpy.resize#
- jax.numpy.resize(a, new_shape)[source]#
返回一个具有指定形状的新数组。
JAX 实现
numpy.resize()
.- 参数:
a (ArrayLike) – 输入数组或标量。
new_shape (Shape) – 整数或整数元组。指定调整大小后的数组的形状。
- 返回值:
一个具有指定形状的调整大小后的数组。如果调整大小后的数组大于原始数组,则 `a` 的元素将在调整大小后的数组中重复。
- 返回类型:
参见
jax.numpy.reshape()
: 返回数组的重塑副本。jax.numpy.repeat()
: 使用重复元素构建数组。
示例
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)