jax.numpy.eye

内容

jax.numpy.eye#

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[source]#

创建方形或矩形的单位矩阵

JAX 的 numpy.eye() 实现。

参数:
  • N (DimSize) – 指定数组第一维的整数。

  • M (DimSize | None | None) – 指定数组第二维的可选整数;默认为与 N 相同的值。

  • k (int | ArrayLike) – 指定对角线偏移量的可选整数。对上对角线使用正值,对下对角线使用负值。默认为零。

  • dtype (DTypeLike | None | None) – 可选的 dtype;默认为浮点数。

  • 设备 (xc.Device | 分片 | None | None) – 可选 DeviceSharding,将创建的数组提交到该设备或分片。

返回值:

形状为 (N, M) 的单位矩阵,如果未指定 M,则形状为 (N, N)

返回类型:

数组

另请参阅

jax.numpy.identity(): 用于生成方单位矩阵的更简单 API。

示例

一个简单的 3x3 单位矩阵

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

带偏移对角线的整数单位矩阵

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

非方单位矩阵

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)