jax.numpy.linalg.matrix_norm

内容

jax.numpy.linalg.matrix_norm#

jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')[source]#

计算矩阵或矩阵堆栈的范数。

JAX 实现 numpy.linalg.matrix_norm()

参数:
  • x (ArrayLike) – 形状为 (..., M, N) 的数组,用于取范数。

  • keepdims (bool) – 如果为 True,则在输出中保留缩减的维度。

  • ord (str) – 指定范数类型的字符串或整数;默认为 Frobenius 范数。有关可用选项的详细信息,请参阅 numpy.linalg.norm()

返回值:

包含 x 范数的数组。如果 keepdims 为 False,则形状为 x.shape[:-2],如果 keepdims 为 True,则形状为 (..., 1, 1)

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)