jax.numpy.linalg.vector_norm#
- jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)[source]#
计算向量或向量批次的向量范数。
JAX 实现
numpy.linalg.vector_norm()
.- 参数:
x (ArrayLike) – 要获取范数的 N 维数组。
axis (int | None | None) – 可选的轴,用于计算向量范数。如果为 None(默认),则
x
会被扁平化,并且范数将针对所有值计算。keepdims (bool) – 如果为 True,则在输出中保留减少的维度。
ord (int | str) – 指定范数类型的字符串或整数;默认为 2 范数。有关可用选项的详细信息,请参见
numpy.linalg.norm()
.
- 返回值:
包含
x
范数的数组。- 返回类型:
另请参阅
jax.numpy.linalg.matrix_norm()
: 矩阵或矩阵堆栈的范数。jax.numpy.linalg.norm()
: 更通用的矩阵或向量范数。
示例
单个向量的范数
>>> x = jnp.array([1., 2., 3.]) >>> jnp.linalg.vector_norm(x) Array(3.7416575, dtype=float32)
一批向量的范数
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 7.]]) >>> jnp.linalg.vector_norm(x, axis=1) Array([3.7416575, 9.486833 ], dtype=float32)