jax.numpy.linalg.vector_norm

内容

jax.numpy.linalg.vector_norm#

jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)[source]#

计算向量或向量批次的向量范数。

JAX 实现 numpy.linalg.vector_norm().

参数:
  • x (ArrayLike) – 要获取范数的 N 维数组。

  • axis (int | None | None) – 可选的轴,用于计算向量范数。如果为 None(默认),则 x 会被扁平化,并且范数将针对所有值计算。

  • keepdims (bool) – 如果为 True,则在输出中保留减少的维度。

  • ord (int | str) – 指定范数类型的字符串或整数;默认为 2 范数。有关可用选项的详细信息,请参见 numpy.linalg.norm().

返回值:

包含 x 范数的数组。

返回类型:

Array

另请参阅

示例

单个向量的范数

>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)

一批向量的范数

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)