jax.numpy.atleast_1d

内容

jax.numpy.atleast_1d#

jax.numpy.atleast_1d() list[Array][source]#
jax.numpy.atleast_1d(x: ArrayLike, /) Array
jax.numpy.atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) list[Array]

将输入转换为至少具有 1 维的数组。

JAX 实现 numpy.atleast_1d().

参数:

arguments. (零个多个类数组)

返回值:

与输入值相对应的数组或数组列表。 形状为 () 的数组将转换为形状为 (1,) 的数组,其他形状的数组将保持不变。

示例

标量参数将转换为 1D,长度为 1 的数组

>>> x = jnp.float32(1.0)
>>> jnp.atleast_1d(x)
Array([1.], dtype=float32)

更高维的输入将保持不变

>>> y = jnp.arange(4)
>>> jnp.atleast_1d(y)
Array([0, 1, 2, 3], dtype=int32)

可以一次性将多个参数传递给函数,在这种情况下,将返回结果列表

>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]