jax.numpy.atleast_3d

内容

jax.numpy.atleast_3d#

jax.numpy.atleast_3d() list[Array][source]#
jax.numpy.atleast_3d(x: ArrayLike, /) Array
jax.numpy.atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) list[Array]

将输入转换为至少具有 3 个维度的数组。

JAX 实现的 numpy.atleast_3d().

参数:

arguments. (更多数组)

返回值:

与输入值相对应的数组或数组列表。 形状为 () 的数组被转换为形状 (1, 1, 1),形状为 (N,) 的一维数组被转换为形状 (1, N, 1),形状为 (M, N) 的二维数组被转换为形状 (M, N, 1),所有其他形状的数组保持不变。

示例

标量参数被转换为 3D,大小为 1 的数组

>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)

一维数组有一个单元维度前置和追加

>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)

二维数组有一个单元维度追加

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)

可以一次将多个参数传递给函数,在这种情况下,将返回结果列表

>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]