jax.numpy.linalg.matrix_norm#
- jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')[源代码]#
计算矩阵或矩阵堆栈的范数。
JAX 实现的
numpy.linalg.matrix_norm()
- 参数:
x (类数组) – 用于计算范数的形状为
(..., M, N)
的数组。keepdims (bool) – 如果为 True,则在输出中保留缩减的维度。
ord (str | int) – 一个字符串或整数,指定范数的类型;默认为 Frobenius 范数。 有关可用选项的详细信息,请参阅
numpy.linalg.norm()
。
- 返回值:
包含
x
的范数的数组。 如果keepdims
为 False,则形状为x.shape[:-2]
,如果keepdims
为 True,则形状为(..., 1, 1)
。- 返回类型:
另请参阅
jax.numpy.linalg.vector_norm()
: 向量或向量堆栈的范数。jax.numpy.linalg.norm()
: 更通用的矩阵或向量范数。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.linalg.matrix_norm(x) Array(16.881943, dtype=float32)