jax.ShapeDtypeStruct

jax.ShapeDtypeStruct#

class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False)[source]#

数组的形状、数据类型和其他静态属性的容器。

ShapeDtypeStruct 通常与 jax.eval_shape() 一起使用。

参数:
  • shape – 表示数组形状的整数序列

  • dtype – 类 dtype 的对象

  • sharding – (可选) jax.Sharding 对象

__init__(shape, dtype, *, sharding=None, weak_type=False)[source]#

方法

__init__(shape, dtype, *[, sharding, weak_type])

属性

shape

dtype

sharding

weak_type

layout

ndim

size