jax.numpy.eye#
- jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[源代码]#
创建一个方形或矩形单位矩阵
numpy.eye()
的 JAX 实现。- 参数:
- 返回:
形状为
(N, M)
的单位数组,或者如果未指定M
,则为(N, N)
。- 返回类型:
另请参阅
jax.numpy.identity()
: 用于生成方形单位矩阵的更简单的 API。示例
一个简单的 3x3 单位矩阵
>>> jnp.eye(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
带有偏移对角线的整数单位矩阵
>>> jnp.eye(3, k=1, dtype=int) Array([[0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=int32) >>> jnp.eye(3, k=-1, dtype=int) Array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=int32)
非方形单位矩阵
>>> jnp.eye(3, 5, k=1) Array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]], dtype=float32)