jax.numpy.ix_#
- jax.numpy.ix_(*args)[source]#
从 N 个一维序列返回一个多维网格(开放网格)。
numpy.ix_()
的 JAX 实现。示例
>>> rows = jnp.array([0, 2]) >>> cols = jnp.array([1, 3]) >>> open_mesh = jnp.ix_(rows, cols) >>> open_mesh (Array([[0], [2]], dtype=int32), Array([[1, 3]], dtype=int32)) >>> [grid.shape for grid in open_mesh] [(2, 1), (1, 2)] >>> x = jnp.array([[10, 20, 30, 40], ... [50, 60, 70, 80], ... [90, 100, 110, 120], ... [130, 140, 150, 160]]) >>> x[open_mesh] Array([[ 20, 40], [100, 120]], dtype=int32)