jax.experimental.sparse 模块#

jax.experimental.sparse 模块包括 JAX 中对稀疏矩阵操作的实验性支持。它正在积极开发中,API 可能会发生变化。主要提供的接口是 BCOO 稀疏数组类型和 sparsify() 转换。

批次坐标 (BCOO) 稀疏矩阵#

目前 JAX 中可用的主要高级稀疏对象是 BCOO,即批次坐标稀疏数组,它提供与 JAX 转换兼容的压缩存储格式,特别是 JIT(例如,jax.jit())、批处理(例如,jax.vmap())和自动微分(例如,jax.grad())。

这是一个从密集数组创建稀疏数组的示例

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

使用 todense() 方法转换回密集数组

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

BCOO 格式是标准 COO 格式的某种修改版本,密集表示可以在 dataindices 属性中看到

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)

BCOO 对象具有熟悉的类似数组的属性,以及特定于稀疏的属性

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4

BCOO 对象还实现了一些类似数组的方法,允许您直接在 jax 程序中使用它们。例如,这里我们计算转置的矩阵-向量积

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)

BCOO 对象旨在与 JAX 转换兼容,包括 jax.jit()jax.vmap()jax.grad() 等。例如

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)

但请注意,在正常情况下,jax.numpyjax.lax 函数不知道如何处理稀疏矩阵,因此尝试计算诸如 jnp.dot(M_sp.T, y) 之类的内容会导致错误(但是,请参阅下一节)。

Sparsify 转换#

JAX 稀疏实现的一个总体目标是提供一种从密集计算无缝切换到稀疏计算的方法,而无需修改密集实现。此稀疏实验通过 sparsify() 转换来实现此目标。

考虑这个函数,它从矩阵和向量输入计算更复杂的结果

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

如果我们直接将稀疏矩阵传递给此函数,则会导致错误,因为 jnp 函数无法识别稀疏输入。但是,使用 sparsify(),我们可以获得一个接受稀疏矩阵的函数版本

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

sparsify() 的支持包括大量最常见的原语,包括

  • 广义(批处理)矩阵乘积和爱因斯坦求和 (dot_general_p)

  • 保留零的逐元素二元运算(例如,add_pmul_p 等)

  • 保留零的逐元素一元运算(例如,abs_pjax.lax.neg_p 等)

  • 求和归约 (reduce_sum_p)

  • 一般索引操作 (slice_plax.dynamic_slice_plax.gather_p)

  • 连接和堆叠 (concatenate_p)

  • 转置和重塑 ((transpose_preshape_psqueeze_pbroadcast_in_dim_p)

  • 一些高阶函数 (cond_pwhile_pscan_p)

  • 一些简单的 1D 卷积 (conv_general_dilated_p)

几乎任何可以降到这些受支持的原语的 jax.numpy 函数都可以在 sparsify 转换中使用,以对稀疏数组进行操作。这组原语足以支持相对复杂的稀疏工作流程,如下一节所示。

示例:稀疏逻辑回归#

作为更复杂稀疏工作流程的示例,让我们考虑在 JAX 中实现的简单逻辑回归。请注意,以下实现没有引用稀疏性

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

这将返回密集逻辑回归问题的最佳拟合参数。为了在稀疏数据上拟合相同的模型,我们可以应用 sparsify() 转换

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]

稀疏 API 参考#

sparsify(f[, use_tracer])

实验性稀疏化转换。

grad(fun[, argnums, has_aux])

与稀疏感知版本的 jax.grad()

value_and_grad(fun[, argnums, has_aux])

与稀疏感知版本的 jax.value_and_grad()

empty(shape[, dtype, index_dtype, sparse_format])

创建一个空的稀疏数组。

eye(N[, M, k, dtype, index_dtype, sparse_format])

创建 2D 稀疏单位矩阵。

todense(arr)

将输入转换为密集矩阵。

random_bcoo(key, shape, *[, dtype, ...])

生成一个随机 BCOO 矩阵。

JAXSparse(args, *, shape)

高级 JAX 稀疏对象的基础类。

BCOO 数据结构#

BCOO 是 *Batched COO 格式*,是 jax.experimental.sparse 中实现的主要稀疏数据结构。其操作与 JAX 的核心转换兼容,包括批处理(例如,jax.vmap())和自动微分(例如,jax.grad())。

BCOO(args, *, shape[, indices_sorted, ...])

JAX 中实现的实验性批处理 COO 矩阵

bcoo_broadcast_in_dim(mat, *, shape, ...[, ...])

通过复制数据来扩展 BCOO 数组的大小和秩。

bcoo_concatenate(operands, *, dimension)

jax.lax.concatenate() 的稀疏实现

bcoo_dot_general(lhs, rhs, *, dimension_numbers)

一种通用收缩操作。

bcoo_dot_general_sampled(A, B, indices, *, ...)

一种在给定稀疏索引处计算输出的收缩操作。

bcoo_dynamic_slice(mat, start_indices, ...)

{func}`jax.lax.dynamic_slice` 的稀疏实现。

bcoo_extract(sparr, arr, *[, assume_unique])

根据稀疏数组的索引从稠密数组中提取值。

bcoo_fromdense(mat, *[, nse, n_batch, ...])

从稠密矩阵创建 BCOO 格式的稀疏矩阵。

bcoo_gather(operand, start_indices, ...[, ...])

lax.gather 的 BCOO 版本。

bcoo_multiply_dense(sp_mat, v)

稀疏数组和稠密数组之间的逐元素乘法。

bcoo_multiply_sparse(lhs, rhs)

两个稀疏数组的逐元素乘法。

bcoo_update_layout(mat, *[, n_batch, ...])

更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。

bcoo_reduce_sum(mat, *, axes)

对给定轴的数组元素求和。

bcoo_reshape(mat, *, new_sizes[, ...])

{func}`jax.lax.reshape` 的稀疏实现。

bcoo_slice(mat, *, start_indices, limit_indices)

{func}`jax.lax.slice` 的稀疏实现。

bcoo_sort_indices(mat)

对 BCOO 数组的索引进行排序。

bcoo_squeeze(arr, *, dimensions)

{func}`jax.lax.squeeze` 的稀疏实现。

bcoo_sum_duplicates(mat[, nse])

对 BCOO 数组中重复的索引求和,返回一个具有排序索引的数组。

bcoo_todense(mat)

将批处理的稀疏矩阵转换为稠密矩阵。

bcoo_transpose(mat, *, permutation)

转置 BCOO 格式的数组。

BCSR 数据结构#

BCSR批处理压缩稀疏行格式,正在开发中。它的操作与 JAX 的核心转换兼容,包括批处理 (例如 jax.vmap()) 和自动微分 (例如 jax.grad())。

BCSR(args, *, shape[, indices_sorted, ...])

在 JAX 中实现的实验性批处理 CSR 矩阵。

bcsr_dot_general(lhs, rhs, *, dimension_numbers)

一种通用收缩操作。

bcsr_extract(indices, indptr, mat)

在给定的 BCSR(索引,indptr)处从稠密矩阵中提取值。

bcsr_fromdense(mat, *[, nse, n_batch, ...])

从稠密矩阵创建 BCSR 格式的稀疏矩阵。

bcsr_todense(mat)

将批处理的稀疏矩阵转换为稠密矩阵。

其他稀疏数据结构#

其他稀疏数据结构包括 COOCSRCSC。这些是简单稀疏结构的参考实现,实现了一些核心操作。它们的操作通常与诸如 jax.grad() 等自动微分转换兼容,但不与诸如 jax.vmap() 等批处理转换兼容。

COO(args, *, shape[, rows_sorted, cols_sorted])

在 JAX 中实现的实验性 COO 矩阵。

CSC(args, *, shape)

在 JAX 中实现的实验性 CSC 矩阵;API 可能会更改。

CSR(args, *, shape)

在 JAX 中实现的实验性 CSR 矩阵。

coo_fromdense(mat, *[, nse, index_dtype])

从稠密矩阵创建 COO 格式的稀疏矩阵。

coo_matmat(mat, B, *[, transpose])

COO 稀疏矩阵和稠密矩阵的乘积。

coo_matvec(mat, v[, transpose])

COO 稀疏矩阵和稠密向量的乘积。

coo_todense(mat)

将 COO 格式的稀疏矩阵转换为稠密矩阵。

csr_fromdense(mat, *[, nse, index_dtype])

从稠密矩阵创建 CSR 格式的稀疏矩阵。

csr_matmat(mat, B, *[, transpose])

CSR 稀疏矩阵和稠密矩阵的乘积。

csr_matvec(mat, v[, transpose])

CSR 稀疏矩阵和稠密向量的乘积。

csr_todense(mat)

将 CSR 格式的稀疏矩阵转换为稠密矩阵。

jax.experimental.sparse.linalg#

稀疏线性代数例程。

spsolve(data, indices, indptr, b[, tol, reorder])

使用 QR 分解的稀疏直接求解器。

lobpcg_standard(A, X[, m, tol])

使用 LOBPCG 例程计算前 k 个标准特征值。