jax.numpy.linalg.norm#
- jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[源代码]#
计算矩阵或向量的范数。
numpy.linalg.norm()
的 JAX 实现。- 参数:
- 返回:
包含 x 的指定范数的数组。
- 返回类型:
注意事项
计算的范数的类型取决于
ord
的值以及被缩减的轴的数量。对于 向量范数(即单个轴缩减)
ord=None
(默认) 计算 2-范数ord=inf
计算max(abs(x))
ord=-inf
计算min(abs(x))
ord=0
计算sum(x!=0)
对于其他数值,计算
sum(abs(x) ** ord)**(1/ord)
对于 矩阵范数(即两个轴缩减)
ord='fro'
或ord=None
(默认) 计算 Frobenius 范数ord='nuc'
计算核范数,即奇异值的总和ord=1
计算max(abs(x).sum(0))
ord=-1
计算min(abs(x).sum(0))
ord=2
计算 2-范数,即最大的奇异值ord=-2
计算最小的奇异值
示例
向量范数
>>> x = jnp.array([3., 4., 12.]) >>> jnp.linalg.norm(x) Array(13., dtype=float32) >>> jnp.linalg.norm(x, ord=1) Array(19., dtype=float32) >>> jnp.linalg.norm(x, ord=0) Array(3., dtype=float32)
矩阵范数
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 7.]]) >>> jnp.linalg.norm(x) # Frobenius norm Array(10.198039, dtype=float32) >>> jnp.linalg.norm(x, ord='nuc') # nuclear norm Array(10.762535, dtype=float32) >>> jnp.linalg.norm(x, ord=1) # 1-norm Array(10., dtype=float32)
批量向量范数
>>> jnp.linalg.norm(x, axis=1) Array([3.7416575, 9.486833 ], dtype=float32)