jax.Array.view#

抽象 Array.view(dtype=None, type=None)[源代码]#

返回数组的按位副本,视为新的 dtype。

这是对 jax.lax.bitcast_convert_type() 更完整功能的包装。

如果源和目标 dtype 具有相同的位宽,则结果的形状与输入数组相同。 如果目标 dtype 的位宽与源不同,则相应地调整结果的最后一个轴的大小。

>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)

在所有情况下,涉及布尔值的转换都未明确定义。关于如上所述的结果形状,布尔值被视为具有 8 位的位宽。然而,当转换为布尔数组时,输入应仅包含 0 或 1 字节。否则,结果可能是不可预测的,或者可能会根据结果的使用方式而改变。

此转换是保证且安全的

>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False,  True], dtype=bool)

但是,对于涉及视图的任何表达式的结果,例如:jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_),没有任何保证。 特别是,结果可能会在 JAX 版本之间以及根据平台而变化。为了安全地将这样的数组转换为布尔数组,请将其与 0 进行比较

>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True,  True, False], dtype=bool)
参数:
  • self (Array)

  • dtype (DTypeLike | None)

  • type (None)

返回类型:

Array