jax.numpy.eye#

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[源代码]#

创建一个方形或矩形单位矩阵

numpy.eye() 的 JAX 实现。

参数:
  • N (DimSize) – 指定数组第一个维度的整数。

  • M (DimSize | None | None) – 可选整数,指定数组的第二个维度;默认为与 N 相同的值。

  • k (int | ArrayLike) – 可选整数,指定对角线的偏移量。正值用于上对角线,负值用于下对角线。默认为零。

  • dtype (DTypeLike | None | None) – 可选 dtype;默认为浮点数。

  • device (xc.Device | Sharding | None | None) – 可选 DeviceSharding,用于将创建的数组提交到其中。

返回:

形状为 (N, M) 的单位数组,或者如果未指定 M,则为 (N, N)

返回类型:

数组

另请参阅

jax.numpy.identity(): 用于生成方形单位矩阵的更简单的 API。

示例

一个简单的 3x3 单位矩阵

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

带有偏移对角线的整数单位矩阵

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

非方形单位矩阵

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)