jax.numpy.identity

内容

jax.numpy.identity#

jax.numpy.identity(n, dtype=None)[source]#

创建一个方阵单位矩阵

JAX 实现 numpy.identity().

参数:
  • n (DimSize) – 指定每个数组维度的尺寸的整数。

  • dtype (DTypeLike | None | None) – 可选的 dtype;默认为浮点数。

返回值:

形状为 (n, n) 的单位矩阵。

返回类型:

数组

参见

jax.numpy.eye(): 非方阵和/或偏移单位矩阵。

示例

一个简单的 3x3 单位矩阵

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

一个 2x2 整数单位矩阵

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)