jax.numpy.meshgrid#

jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')[源代码]#

从 N 个一维向量构造 N 维网格数组。

numpy.meshgrid() 的 JAX 实现。

参数:
  • xi (ArrayLike) – 要转换为网格的 N 个数组。

  • copy (bool) – 是否复制输入数组。JAX 仅支持 copy=True,尽管在 JIT 编译下,编译器可能会选择避免复制。

  • sparse (bool) – 如果为 False(默认),则每个返回的数组的形状为 [len(x1), len(x2), ..., len(xN)]。如果为 True,则返回的数组的形状为 [1, 1, ..., len(xi), ..., 1, 1]

  • indexing (str) – 选项为 'xy' 表示笛卡尔索引(默认)或 'ij' 表示矩阵索引。

返回:

长度为 N 的网格数组列表。

返回类型:

list[Array]

另请参阅

示例

对于以下示例,我们将使用这些 1D 数组作为输入

>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])

2D 笛卡尔网格

>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]

2D 稀疏笛卡尔网格

>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]

2D 矩阵索引网格

>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]