jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[源代码]#

计算矩阵或向量的范数。

numpy.linalg.norm() 的 JAX 实现。

参数:
  • x (ArrayLike) – 用于计算范数的 N 维数组。

  • ord (int | str | None) – 指定要采用的范数类型。默认情况下,矩阵使用弗罗贝尼乌斯范数,向量使用 2-范数。有关其他选项,请参见下面的“备注”。

  • axis (None | tuple[int, ...] | int) – 指定计算范数的轴的整数或整数序列。默认为 x 的所有轴。

  • keepdims (bool) – 如果为 True,则输出数组将具有与输入相同的维度数,被缩减轴的大小将替换为 1(默认值:False)。

返回:

包含 x 的指定范数的数组。

返回类型:

数组

注意事项

计算的范数的类型取决于 ord 的值以及被缩减的轴的数量。

对于 向量范数(即单个轴缩减)

  • ord=None (默认) 计算 2-范数

  • ord=inf 计算 max(abs(x))

  • ord=-inf 计算 min(abs(x))

  • ord=0 计算 sum(x!=0)

  • 对于其他数值,计算 sum(abs(x) ** ord)**(1/ord)

对于 矩阵范数(即两个轴缩减)

  • ord='fro'ord=None (默认) 计算 Frobenius 范数

  • ord='nuc' 计算核范数,即奇异值的总和

  • ord=1 计算 max(abs(x).sum(0))

  • ord=-1 计算 min(abs(x).sum(0))

  • ord=2 计算 2-范数,即最大的奇异值

  • ord=-2 计算最小的奇异值

示例

向量范数

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩阵范数

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批量向量范数

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)