jax.numpy.resize#
- jax.numpy.resize(a, new_shape)[source]#
返回具有指定形状的新数组。
JAX实现的
numpy.resize()
。- 参数:
a (类数组对象) – 输入数组或标量。
new_shape (形状) – 整数或整数元组。指定调整大小后的数组的形状。
- 返回:
一个具有指定形状的调整大小后的数组。如果调整大小后的数组大于原始数组,则
a
的元素将在调整大小后的数组中重复。- 返回类型:
另请参阅
jax.numpy.reshape()
:返回数组的重塑副本。jax.numpy.repeat()
:从重复的元素构造数组。
示例
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)