标量预取和块稀疏计算#

在本教程中,我们将介绍 Pallas 中块稀疏计算的基础知识。稀疏计算是编写自定义 Pallas 内核而不是简单使用 JAX/XLA 的主要原因,因为由于静态数组形状,通常很难在 XLA 中表达执行动态计算量的程序。在本教程中,我们将学习如何使用 Pallas 的标量预取功能来编写块稀疏内核,该内核可以动态跳过计算和内存块。

import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite

使用标量预取进行动态块索引#

我们将利用 Pallas 的“标量预取”功能,使我们能够编写稀疏内核。标量预取允许您将少量数据传递到 SMEM(“标量内存”),这些数据在管道开始之前加载(“预取”)。由于此数据在管道之前加载,因此可以在每个 BlockSpec 的 index_map 中使用,从而允许您执行依赖于数据的索引计算。本教程的主要目标是回顾利用此功能的常见编程模式。

要使用标量预取,请使用 pltpu.PrefetchScalarGridSpec 代替标准的 pl.GridSpec

class PrefetchScalarGridSpec:
  def __init__(self,
    num_scalar_prefetch: int,
    grid: tuple[int, ...],
    in_specs: PyTree[BlockSpec],
    out_specs: PyTree[BlockSpec],
    scratch_shapes: tuple[MemorySpace, ...]):
      ...

num_scalar_prefetch 参数指示标量预取值的数量。当此值设置为非零值时,它会更改内核和索引映射的调用签名,以期望额外的预取值。传递给 index_map 和内核的预取 Ref 都分配在 SMEM 中,并且不会像没有定义 BlockSpec 那样被划分为块。此外,index_map 和内核的参数顺序始终是固定的,如下所述

  • 每个 BlockSpecindex_map 现在期望预取 Ref 在网格索引之后出现

def index_map(*grid_indices, *prefetch_refs):
    ...
  • 用户定义的内核期望预取 Ref 在输入 Ref 之前出现。此外,暂存引用在输出 Ref 之后出现。

def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
    ...
  • 当使用 pallas_call 调用新内核时,pallas_call 返回的函数也期望标量预取参数在输入之前出现,例如

kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)

示例:使用标量预取的块动态切片#

让我们从一个基本示例开始,演示如何使用标量预取功能。我们将实现一个块对齐的动态切片内核,该内核只需根据用户指定的索引从较大的数组中提取一个块

  1. 在内核之外,我们将要提取的块索引计算为: block_idx = (start[0] // size[0], start[1] // size[1])

  2. 我们将 block_idx 作为标量预取参数传递到 pallas_call 中。

  3. 在我们的索引映射中,我们使用块索引通过返回 (block_idx[0], block_idx[1]) 来选择相应的块。

当然,此内核的局限性在于我们的切片大小必须适合内核块内部(受 VMEM 大小的限制),并且我们只能在大小对齐的索引上开始。更高级的内核会将内核块大小与切片大小分离,并允许非对齐的起始索引。

def dynamic_slice_kernel(indices, x_ref, o_ref):
  del indices
  o_ref[...] = x_ref[...]

@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(1, 1),
      in_specs=[pl.BlockSpec(
          sizes,
          lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
      out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
  )

  kernel = pl.pallas_call(
    dynamic_slice_kernel,
    grid_spec=grid_spec,
    out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
  )
  # Checkify inserts a runtime assert that starts are divisible by block size.
  checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
  checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
  block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
  return kernel(block_idx, x)

shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0

稀疏内核:表示稀疏数据#

在我们深入研究实现稀疏内核之前,让我们首先回顾一下稀疏矩阵是如何表示的。虽然有几种流行的格式用于存储稀疏矩阵,但我们将遵循坐标列表格式 (COO) 的块变体,其中我们将矩阵存储为 (block_index, block_data) 对的列表。所有未在列表中显式存储的块都假定为零,这意味着如果矩阵中有许多零块,我们可以节省大量内存。

下图演示了如何将 4x4 密集矩阵(左)转换为块 COO 格式(右),块大小为 2x2。请注意,在稀疏格式中,我们可以避免显式存储由所有零元素组成的右上角块。

block_coo

我们将使用以下辅助函数来采样块稀疏矩阵。它返回一个用于检查结果的密集矩阵,以及每个轴的块数据和索引列表。

def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
  """Returns a sampled matrix and its block-sparse representation.

  Args:
    key: RNG Key.
    M: Major array dimension.
    N: Minor array dimension.
    blk_M: Block size along M dimension.
    blk_N: Block size along N dimension.
    p: Probability that a block will be non-zero.
    dtype: dtype of the sampled matrix.

  Returns:
    dense_mat: A (M, N) dense sampled array.
    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
      the non-zero blocks of the matrix.
    indices_i: A (num_blocks,) array of block indices for the first axis.
    indices_j: A (num_blocks,) array of block indices for the second axis.
  """
  mask_key, blocks_key = jax.random.split(key)
  num_blocks = (M // blk_M, N // blk_N)
  # We first sample a block mask, denoting which blocks are nonzero.
  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
  num_blocks = jnp.sum(block_mask)
  indices = jnp.where(block_mask)
  # For each non-zero block, we sample a block of random values.
  block_data = jax.random.uniform(blocks_key,
                                  shape=(num_blocks, blk_M, blk_N),
                                  dtype=dtype)
  # For checking purposes, create the dense version of the sparse matrix.
  dense_mat = jnp.zeros((M, N), dtype=dtype)
  for blk in range(num_blocks):
    idx_i = indices[0][blk]
    idx_j = indices[1][blk]
    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
  return dense_mat, block_data, indices[0], indices[1]

示例:稀疏 @ 密集矩阵乘法#

在我们的第一个示例中,我们将稀疏 LHS 矩阵与密集 RHS 矩阵相乘,以产生密集输出。

我们将使用 2 个循环来构建我们的内核网格 - 外循环遍历 RHS/输出的列,内循环遍历 LHS 的稀疏块。在每次内部循环迭代期间,我们从 LHS 加载一个块,并使用收缩维度 (K) 的块索引在 RHS 中查找相应的块。我们将两个块相乘,并累积到正确的输出块中。一个外部循环迭代将计算整个列的结果,如下图所示

sparse_matmul

重要的是,我们在将块索引传递到内核之前按行分组块索引(例如,[0, 0, 1, 2, 3, 3]),原因有两个。首先,在我们的内核中,我们需要知道何时最初将累加器清零到输出引用中,如果行索引分组,则很容易做到这一点。其次,Pallas 的管道逻辑不允许我们在非连续迭代中重新访问输出 Ref 中的块,因此我们需要在连续的内核迭代中对输出块进行所有累积。这是因为管道发射器会意识到我们在连续迭代中加载相同的输出块,并将该块保留在 VMEM 中。当我们更改输出块时,Pallas 最终会将输出存储到 HBM 中,并假设我们再也不会触摸它。即使内核在其他方面逻辑正确,未能连续访问输出块也会导致值不正确。

M = N = K = 16384
blk_M = blk_N = blk_K = 512


def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
               x_ref, y_ref, _, o_ref, # Kernel inputs.
               accum_scratch,
               ):
  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
  del idxs_k_ref
  blk_idx = pl.program_id(0)
  is_start = blk_idx == 0
  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
  @pl.when(is_start | changed_blocks)
  def _():
    accum_scratch[...] = jnp.zeros_like(accum_scratch)
  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)

  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
  is_end = blk_idx == (num_blocks - 1)
  @pl.when(is_end | next_block_change)
  def _():
    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del j, blk_idxs_i, blk_idxs_k
  return (blk_idx, 0, 0)
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del blk_idxs_i
  return (blk_idxs_k[blk_idx], j)
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del blk_idxs_k
  return (blk_idxs_i[blk_idx], j)

(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=2,
    # Note that while num_blocks is static here, Pallas does support
    # dynamic grid sizes.
    grid=(num_blocks, N // blk_N),
    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              # Placeholder for a zeros-array used by input_output_aliases.
              pl.BlockSpec((blk_M, blk_N), o_map),
              ],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  dsd_kernel,
  grid_spec=grid_spec,
  out_shape=out_shape,
  # We use input-output aliases to zero-out o_ref for blocks that we never
  # visit. By passing in an array of zeros we avoid having o_ref start with
  # uninitialized values.
  input_output_aliases={4: 0},  # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)

ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0

我们可以进行快速基准测试,以比较稀疏内核与 JAX 中密集矩阵乘法的性能。在 TPU v5e 芯片上,此内核的性能比稀疏因子带来的理论 10 倍速度提高了约 6 倍。

这里有一些主要的性能提示,主要集中在减少 HBM/VMEM 之间的通信开销上

  • 使用 dtype=jnp.bfloat16 对性能至关重要,因为它将内存带宽减少了一半。

  • 使用较大的块大小也有帮助,因为矩阵乘法是 \(O(N^3)\) 计算和 \(O(N^2)\) 内存操作。随着 \(N\) 变大,内核变为计算密集型。然而,在实践中,对此的一个反驳是较小的块大小也使得数据更稀疏,因此这是一个应该仔细选择的参数。

# Benchmark Sparse Pallas kernel vs reference JAX implementation

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    return time
  return run


n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)

密集数据上的稀疏访问模式#

在我们之前的示例中,我们考虑了数据本身是稀疏的情况。这在内核结构中表现为一个内核网格中的维度,它是动态的,并在非零块的数量 (num_blocks) 上循环。

当底层数据密集时,我们希望对其执行稀疏计算时,会出现第二种有用的编程模式。在这种情况下,我们的内核网格将是密集的,但是我们希望按照块稀疏掩码的指示跳过网格中的某些块。当在许多机器学习应用程序中使用掩码时,例如自注意力中的因果掩码或局部掩码时,通常会出现这种类型的编程模式。在这些情况下,我们可以完全跳过掩码为零的块中的计算。这种编程模式的示例可以在 jax/experimental/pallas/ops/tpu 中的 Splash Attention 和 Grouped Matrix Multiplication 内核中找到,或者在 PyTorch 的 FlexAttention 中找到。

在处理密集数据上的稀疏访问模式时,主要的性能考虑因素是与流水线的交互。在任何给定的内核迭代中,Pallas 管道发射器都会尝试通过在网格的下一次迭代中为每个 BlockSpec 调用 index_map 来预取下一个数据块。但是,如果我们的计算是稀疏的,我们可能会跳过网格中下一个块的计算,因此我们需要一些方法来告诉管道开始获取*我们没有跳过的下一个块*。为了做到这一点,我们需要构造*预取映射*,其中包含每个内核输入的下一个非跳过数据块的索引。下图说明了如何为以类似 COO 格式存储的块稀疏掩码构造预取映射。

prefetch_map

左图:稀疏访问模式,其中蓝色表示我们需要计算的非零掩码的块。右图:预取映射,其中数组的每个元素都包含下一个非零块数据的索引。

构造好预取映射后,我们可以将该映射作为标量预取参数传递,并在 BlockSpec 的 index_map 函数中查询它。

def mask_index_map(prefetch_map, i, j, ...):
  next_nonzero_block = prefetch_map[i, j]
  return (next_nonzero_block, 0, 0)

我们可以为内核的其他输入构造类似的索引映射。对于密集输入,您很可能需要构造预取映射,这些映射指向网格中下一个非零块索引。我们的下一个示例将提供使用这些预取映射的示例。

示例:具有块稀疏输出掩码的密集 @ 密集矩阵乘法#

在我们的下一个示例中,我们将介绍使用预取映射融合稀疏输出掩码的密集矩阵乘法,以提高流水线性能。我们将使用掩码来选择性地跳过计算被清零的输出块,从而节省计算成本。

由于我们将使用稀疏掩码,因此我们将首先实现一个函数,该函数将以密集格式存储的 N x M 掩码转换为块稀疏格式。此外,我们需要计算预取映射,以帮助管道发射器知道接下来要获取哪个块。总而言之,我们的 sparsify_mask 函数计算

  • 一个形状为 (num_N_blocks, num_M_blocks)block_mask,指示块是否全为零(值 0)或包含非零元素(值 1)。如果 block_mask 的值为 0,我们可以跳过在内核中计算该块。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_mask 数组,由指向 mask_data 中下一个非零块的索引组成。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_i 数组,由掩码中下一个非掩码的 i 索引组成。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_j 数组,由掩码中下一个非掩码的 j 索引组成。

  • 一个形状为 (num_blocks, blk_N, blk_M)mask_data 数组,包含掩码中非零块的数据。

def sparsify_mask(mask: jax.Array,
                  block_shape: tuple[int, int]):
  """Preprocesses a mask into a sparse reprentation.

  Args:
    mask: A boolean array of shape [M, N]
    block_shape: The size of a single block.

  Returns:
    block_mask: A block_shape array of booleans indicating whether a block
      is all-zeros (0) or contains non-zero elements (1).
    prefetch_mask: A block_shape array of integers indicating the index of the
      next non-zero block.
    mask_data: A (num_blocks, block_shape) array containing
      the data for non-zero blocks of the mask.
  """
  M, N = mask.shape
  bm, bn = block_shape

  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
  mask_types_finder = []
  mask_data = []
  mask_type_idxs = []

  next_mask_type_idx = 0
  prefetch_mask = jnp.zeros_like(block_mask)
  next_i = (M // bm) - 1
  next_j = (N // bn) - 1
  prefetch_i = jnp.zeros_like(block_mask)
  prefetch_j = jnp.zeros_like(block_mask)
  for i in range(M // bm, -1, -1):
    for j in range(N // bn, -1, -1):
      mask_block = mask[i * bm :(i + 1) * bm,
                        j * bn :(j + 1) * bn]
      is_nonzero = jnp.any(mask_block)
      if is_nonzero:
        try:
          type_index = mask_types_finder.index(str(mask_block))
        except ValueError:
          type_index = len(mask_types_finder)
          mask_types_finder.append(str(mask_block))
          mask_data.append(mask_block)
        next_mask_type_idx = type_index
        next_i = i
        next_j = j
      else:
        type_index = -1
      mask_type_idxs.append(type_index)
      block_mask = block_mask.at[i, j].set(is_nonzero)
      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
      prefetch_i = prefetch_i.at[i, j].set(next_i)
      prefetch_j = prefetch_j.at[i, j].set(next_j)
  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)

在内核的结构方面,我们使用与之前教程中介绍的标准矩阵乘法内核相同的网格模式,即在 NMK 维度上进行 3 个循环。在内核内部,我们首先检查 block_mask,以查看当前输出块的掩码是否全为零。如果掩码全为零,我们可以跳过计算并移动到下一个块;否则,我们需要计算矩阵乘法,然后对结果进行掩码。

M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024

def sparse_mask_matmul(
    block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
    x_ref, y_ref, mask_ref, o_ref,  # Kernel inputs.
    accum_scratch
    ):
  del prefetch_mask, prefetch_i, prefetch_j
  i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
  should_compute = block_mask_ref[i, j] != 0
  @pl.when(k == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

  # We only compute the output for blocks with non-zero masks.
  # Otherwise we skip the computation entirely.
  @pl.when(should_compute)
  def _():
    result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    accum_scratch[...] += result
    @pl.when(k == pl.num_programs(2) - 1)
    def _():
      o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)

X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))

def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_j
  # Zero-out the k index if the mask is zero, to avoid constantly fetching
  # new blocks in the inner loop for blocks we are skipping.
  k_fetch = (block_mask[i, j] != 0) * k
  return (prefetch_i[i, j], k_fetch)

def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_i
  k_fetch = (block_mask[i, j] != 0) * k
  return (k_fetch, prefetch_j[i, j])

def mask_map(i, j, k, block_mask, prefetch_mask, *_):
  del k, block_mask
  return (prefetch_mask[i, j], 0, 0)

def o_map(i, j, k, *_):
  del k
  return (i, j)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=4,
    grid=(M // blk_M, N // blk_N, K // blk_K),
    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  sparse_mask_matmul,
  grid_spec=grid_spec,
  out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)

ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05

现在让我们将性能与朴素的密集实现进行比较。在 TPU v5e 上,与使用下三角掩码并且仅访问一半可能的输出的理论最佳情况 2 倍相比,我们使用稀疏内核实现了大约 1.8 倍的速度提升。

我们通常预计,随着输入的增大,性能会更接近理论峰值,因为我们没有完全达到理论性能的几个主要原因是:

  • 我们跳过的计算略少于一半,因为沿对角线的块是 0 和 1 的混合,并且对于混合块,我们需要计算整个块。 随着输入量的增大,混合块的开销相对于整体计算而言会变得更小。

  • 随着输入量的增大,流水线气泡在整体运行时间中所占的比例也变小了。

n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)