jax.numpy.meshgrid#
- jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')[源代码]#
从 N 个一维向量构造 N 维网格数组。
numpy.meshgrid()
的 JAX 实现。- 参数:
- 返回:
长度为 N 的网格数组列表。
- 返回类型:
另请参阅
jax.numpy.indices()
: 生成索引网格。jax.numpy.mgrid
: 使用索引语法创建网格。jax.numpy.ogrid
: 使用索引语法创建开放网格。
示例
对于以下示例,我们将使用这些 1D 数组作为输入
>>> x = jnp.array([1, 2]) >>> y = jnp.array([10, 20, 30])
2D 笛卡尔网格
>>> x_grid, y_grid = jnp.meshgrid(x, y) >>> print(x_grid) [[1 2] [1 2] [1 2]] >>> print(y_grid) [[10 10] [20 20] [30 30]]
2D 稀疏笛卡尔网格
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True) >>> print(x_grid) [[1 2]] >>> print(y_grid) [[10] [20] [30]]
2D 矩阵索引网格
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij') >>> print(x_grid) [[1 1 1] [2 2 2]] >>> print(y_grid) [[10 20 30] [10 20 30]]