jax.numpy.atleast_3d#
- jax.numpy.atleast_3d() list[Array] [source]#
- jax.numpy.atleast_3d(x: ArrayLike, /) Array
- jax.numpy.atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) list[Array]
将输入转换为至少具有 3 个维度的数组。
JAX 实现的
numpy.atleast_3d()
.- 参数:
arguments. (零 或 更多数组)
- 返回值:
与输入值相对应的数组或数组列表。 形状为
()
的数组被转换为形状(1, 1, 1)
,形状为(N,)
的一维数组被转换为形状(1, N, 1)
,形状为(M, N)
的二维数组被转换为形状(M, N, 1)
,所有其他形状的数组保持不变。
示例
标量参数被转换为 3D,大小为 1 的数组
>>> x = jnp.float32(1.0) >>> jnp.atleast_3d(x) Array([[[1.]]], dtype=float32)
一维数组有一个单元维度前置和追加
>>> y = jnp.arange(4) >>> jnp.atleast_3d(y).shape (1, 4, 1)
二维数组有一个单元维度追加
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_3d(z).shape (2, 3, 1)
可以一次将多个参数传递给函数,在这种情况下,将返回结果列表
>>> x3, y3 = jnp.atleast_3d(x, y) >>> print(x3) [[[1.]]] >>> print(y3) [[[0] [1] [2] [3]]]