jax.experimental.pallas.pallas_call#

jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None)[源代码]#

在某些输入上调用 Pallas 内核。

请参阅 Pallas 快速入门

参数:
  • kernel (Callable[..., None]) – 内核函数,它接收每个输入和输出的 Ref。Refs 的形状由对应的 in_specsout_specs 中的 block_shape 给出。

  • out_shape (Any) – 一个 jax.ShapeDtypeStruct 的 PyTree,描述输出的形状和数据类型。

  • grid_spec (GridSpec | None | None) – 指定 gridin_specsout_specsscratch_shapes 的另一种方法。如果给定,则不得同时给出其他参数。

  • grid (TupleGrid) – 迭代空间,表示为整数元组。内核执行的次数为 prod(grid)。请参阅 网格,又名循环中的内核 了解详细信息。

  • in_specs (BlockSpecTree) – 一个 jax.experimental.pallas.BlockSpec 的 PyTree,其结构与位置参数的结构匹配。in_specs 的默认值指定所有输入的整个数组,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。请参阅 BlockSpec,又名如何对输入进行分块 了解详细信息。

  • out_specs (BlockSpecTree) – 一个 jax.experimental.pallas.BlockSpec 的 PyTree,其结构与输出的结构匹配。out_specs 的默认值指定整个数组,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。请参阅 BlockSpec,又名如何对输入进行分块 了解详细信息。

  • scratch_shapes (ScratchShapeTree) – 内核所需的后端特定临时对象 PyTree,例如临时缓冲区、同步原语等。

  • input_output_aliases (dict[int, int]) – 一个字典,将某些输入的索引映射到别名它们的输出的索引。这些索引位于扁平化的输入和输出中。

  • debug (bool) – 如果为 True,则 Pallas 会在处理内核时打印各种中间形式。

  • interpret (bool) – 将 pallas_call 作为扫描网格的 jax.jit 运行,该网格的主体是作为 JAX 函数降低的内核。这不需要 TPU 或 GPU,并且是在 CPU 上运行 Pallas 内核的唯一方法。这对于调试很有用。

  • name (str | None | None) – 如果存在,则指定在调试和错误消息中用于此内核调用的名称。在此名称中,我们附加定义内核函数的文件和行,例如:{name} for kernel function {kernel_name} at {file}:{line}。如果缺失,则我们使用 {kernel_name} at {file}:{line}

  • compiler_params (dict[str, Any] | pallas_core.CompilerParams | None | None) – 可选的编译器参数。如果提供 dict,则其形式应为 {platform: {param_name: param_value}},其中 platform 是 ‘mosaic’ 或 ‘triton’。也可以为 TPU 传入 jax.experimental.pallas.tpu.TPUCompilerParams,为 Triton/GPU 传入 jax.experimental.pallas.gpu.TritonCompilerParams

  • backend (_Backend | None | None) – 可选的字符串文字,为 “mosaic_tpu”、“triton” 或 “mosaic_gpu” 中的一个,确定要使用的后端。None 表示让 pallas 决定。

  • cost_estimate (CostEstimate | None | None)

返回:

一个可以调用任意数量的位置数组参数来调用 Pallas 内核的函数。

返回类型:

Callable[…, Any]