jax.numpy.resize#

jax.numpy.resize(a, new_shape)[source]#

返回具有指定形状的新数组。

JAX实现的 numpy.resize()

参数:
  • a (类数组对象) – 输入数组或标量。

  • new_shape (形状) – 整数或整数元组。指定调整大小后的数组的形状。

返回:

一个具有指定形状的调整大小后的数组。如果调整大小后的数组大于原始数组,则 a 的元素将在调整大小后的数组中重复。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> jnp.resize(x, (3, 3))
Array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)
>>> jnp.resize(x, (3, 4))
Array([[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 1, 2, 3]], dtype=int32)
>>> jnp.resize(4, (3, 2))
Array([[4, 4],
       [4, 4],
       [4, 4]], dtype=int32, weak_type=True)