jax.experimental.pallas.BlockSpec

jax.experimental.pallas.BlockSpec#

class jax.experimental.pallas.BlockSpec(block_shape=None, index_map=None, *, memory_space=None, indexing_mode=Blocked)[source]#

指定数组在每次内核调用时如何切片。

有关更多详细信息,请参阅 BlockSpec,即如何将输入分成块

参数:
  • block_shape (元组[整数 | , ...] | )

  • index_map (可调用[..., 任何] | )

  • memory_space (任何 | )

  • indexing_mode (IndexingMode)

__init__(block_shape=None, index_map=None, *, memory_space=None, indexing_mode=Blocked)[源代码]#
参数:
  • block_shape (任何 | | )

  • index_map (任何 | | )

  • memory_space (任何 | | )

  • indexing_mode (IndexingMode)

返回类型:

方法

__init__([block_shape, index_map, ...])

to_block_mapping(origin, array_aval, *, ...)

属性

block_shape

index_map

indexing_mode

memory_space