jax.numpy.ix_

内容

jax.numpy.ix_#

jax.numpy.ix_(*args)[source]#

从 N 个一维序列返回一个多维网格(开放网格)。

numpy.ix_() 的 JAX 实现。

参数:

*args (ArrayLike) – N 个一维数组

返回:

形成开放网格的 Jax 数组元组,每个数组都有 N 个维度。

返回类型:

tuple[Array, …]

示例

>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
      [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)