jax.numpy.linalg.norm

内容

jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[source]#

计算矩阵或向量的范数。

JAX 实现 numpy.linalg.norm().

参数::
  • x (ArrayLike) – 将计算范数的 N 维数组。

  • ord (int | str | None) – 指定要取的范数类型。默认值为矩阵的 Frobenius 范数,向量为 2 范数。有关其他选项,请参见下面的说明。

  • axis ( | 元组[整数, ...] | 整数) – 指定计算范数的轴的整数或整数序列。默认情况下为 x 的所有轴。

  • keepdims (布尔值) – 如果为 True,输出数组将与输入具有相同数量的维度,缩减的轴的大小将替换为 1(默认值:False)。

返回值:

包含 x 的指定范数的数组。

返回类型:

数组

注意

计算的范数类型取决于 ord 的值和被缩减的轴的数量。

对于向量范数(即单个轴缩减)

  • ord=None(默认)计算 2-范数

  • ord=inf 计算 max(abs(x))

  • ord=-inf 计算 min(abs(x))``

  • ord=0 计算 sum(x!=0)

  • 对于其他数值,计算 sum(abs(x) ** ord)**(1/ord)

对于矩阵范数(即两个轴缩减)

  • ord='fro'ord=None(默认)计算 Frobenius 范数

  • ord='nuc' 计算核范数,或奇异值的总和

  • ord=1 计算 max(abs(x).sum(0))

  • ord=-1 计算 min(abs(x).sum(0))

  • ord=2 计算 2-范数,即最大的奇异值

  • ord=-2 计算最小的奇异值

例子

向量范数

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩阵范数

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批向量范数

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)