jax.experimental.pallas.MemoryRef#

class jax.experimental.pallas.MemoryRef(shape, dtype, memory_space)[源代码]#

类似于 jax.ShapeDtypeStruct,但带有内存空间。

参数:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype)

  • memory_space (Any)

__init__(shape, dtype, memory_space)#
参数:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype)

  • memory_space (Any)

返回类型:

None

方法

__init__(shape, dtype, memory_space)

get_array_aval()

get_ref_aval()

属性

shape

dtype

memory_space