jax.numpy.ogrid#
- jax.numpy.ogrid = <jax._src.numpy.index_tricks._Ogrid object>#
返回开放式多维“meshgrid”。
numpy.ogrid
的 LAX 后端实现。这是jax.numpy.meshgrid()
,其中sparse=True
提供的功能的便捷包装器。另请参阅
jnp.mgrid: jnp.ogrid 的密集版本
示例
传递
[start:stop:step]
来生成类似于jax.numpy.arange()
的值>>> jnp.ogrid[0:4:1] Array([0, 1, 2, 3], dtype=int32)
传递虚数步长生成的值类似于
jax.numpy.linspace()
>>> jnp.ogrid[0:1:4j] Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
可以使用多个切片创建稀疏的索引网格
>>> jnp.ogrid[:2, :3] [Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)]