jax.Array.astype

jax.Array.astype#

abstract Array.astype(dtype, copy=False, device=None)[source]#

复制数组并转换为指定的 dtype。

这是通过 jax.lax.convert_element_type() 实现的,在某些情况下,它的行为可能与 numpy.ndarray.astype() 略有不同。特别是,浮点数到整数和整数到浮点数转换的细节取决于实现。

参数:
  • self (Array)

  • dtype (DTypeLike | None)

  • copy (bool)

  • device (xc.Device | Sharding | None)

返回类型:

Array