jax.numpy.atleast_1d#
- jax.numpy.atleast_1d(*arys)[源代码]#
将输入转换为至少具有 1 维的数组。
numpy.atleast_1d()
的 JAX 实现。- 参数:
arguments. (零个或多个类数组)
arys (类数组)
- 返回值:
对应于输入值的数组或数组列表。形状为
()
的数组将转换为形状(1,)
,其他形状的数组保持不变。- 返回类型:
示例
标量参数被转换为 1D,长度为 1 的数组
>>> x = jnp.float32(1.0) >>> jnp.atleast_1d(x) Array([1.], dtype=float32)
更高维度的输入保持不变
>>> y = jnp.arange(4) >>> jnp.atleast_1d(y) Array([0, 1, 2, 3], dtype=int32)
可以一次性将多个参数传递给函数,在这种情况下,将返回一个结果列表。
>>> jnp.atleast_1d(x, y) [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]