jax.numpy.diagflat#
- jax.numpy.diagflat(v, k=0)[源代码]#
返回一个二维数组,其中扁平化的输入数组沿对角线排列。
numpy.diagflat()
的 JAX 实现。这与 np.diagflat 对于某些标量值的 v 有所不同。JAX 始终返回一个二维数组,而 NumPy 可能根据 v 的类型返回一个标量。
- 参数:
v (ArrayLike) – 输入数组。可以是 N 维的,但会被展平为 1 维。
k (int) – 可选参数,默认为 0。对角线偏移量。正值将对角线置于主对角线上方,负值将对角线置于主对角线下方。
- 返回:
一个二维数组,其输入元素沿指定偏移量 (k) 的对角线放置。其余条目填充为零。
- 返回类型:
示例
>>> jnp.diagflat(jnp.array([1, 2, 3])) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32) >>> jnp.diagflat(jnp.array([1, 2, 3]), k=1) Array([[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]], dtype=int32) >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.diagflat(a) Array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=int32)