jax.numpy.ogrid

内容

jax.numpy.ogrid#

jax.numpy.ogrid = <jax._src.numpy.index_tricks._Ogrid object>#

返回开放的多维“网格”。

LAX 后端实现的 numpy.ogrid。这是一个针对 jax.numpy.meshgrid() 提供的功能的便捷包装器,其中 sparse=True

参见

jnp.mgrid: jnp.ogrid 的密集版本

示例

传递 [start:stop:step] 以生成类似于 jax.numpy.arange() 的值

>>> jnp.ogrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)

传递虚数步长会生成类似于 jax.numpy.linspace() 的值

>>> jnp.ogrid[0:1:4j]
Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)

可以使用多个切片创建稀疏索引网格

>>> jnp.ogrid[:2, :3]
[Array([[0],
        [1]], dtype=int32),
 Array([[0, 1, 2]], dtype=int32)]