jax.experimental.sparse
模块#
注意
jax.experimental.sparse
中的方法是实验性的参考实现,不建议在对性能要求很高的应用中使用。
jax.experimental.sparse
模块包括对 JAX 中稀疏矩阵操作的实验性支持。它正在积极开发中,API 可能会发生变化。主要提供的接口是 BCOO
稀疏数组类型和 sparsify()
转换。
批量坐标 (BCOO) 稀疏矩阵#
目前在 JAX 中可用的主要高级稀疏对象是 BCOO
,或批量坐标稀疏数组,它提供了一种压缩存储格式,与 JAX 转换兼容,特别是 JIT (例如 jax.jit()
)、批处理 (例如 jax.vmap()
) 和自动微分 (例如 jax.grad()
)。
这是一个从密集数组创建稀疏数组的示例
>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
... [3., 0., 0., 0.],
... [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)
使用 todense()
方法转换回密集数组
>>> M_sp.todense()
Array([[0., 1., 0., 2.],
[3., 0., 0., 0.],
[0., 0., 4., 0.]], dtype=float32)
BCOO 格式是标准 COO 格式的一个修改版本,可以在 data
和 indices
属性中看到密集的表示形式
>>> M_sp.data # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
[0, 3],
[1, 0],
[2, 2]], dtype=int32)
BCOO 对象具有熟悉的类似数组的属性,以及特定于稀疏的属性
>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse # "number of specified elements"
4
BCOO 对象还实现了一些类似数组的方法,允许您在 jax 程序中直接使用它们。例如,这里我们计算转置矩阵向量积
>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18., 3., 20., 6.], dtype=float32)
>>> M.T @ y # Compare to dense version
Array([18., 3., 20., 6.], dtype=float32)
BCOO 对象旨在与 JAX 转换兼容,包括 jax.jit()
、 jax.vmap()
、jax.grad()
等。例如
>>> from jax import grad, jit
>>> def f(y):
... return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
但请注意,在正常情况下,jax.numpy
和 jax.lax
函数不知道如何处理稀疏矩阵,因此尝试计算诸如 jnp.dot(M_sp.T, y)
之类的操作会导致错误(但是,请参阅下一节)。
稀疏化转换#
JAX 稀疏实现的一个总体目标是提供一种从密集计算无缝切换到稀疏计算的方法,而无需修改密集实现。此稀疏实验通过 sparsify()
转换来实现此目标。
考虑以下函数,它从矩阵和向量输入计算一个更复杂的结果
>>> def f(M, v):
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
如果我们直接将稀疏矩阵传递给此函数,则会导致错误,因为 jnp
函数无法识别稀疏输入。但是,使用 sparsify()
,我们可以获得此函数的一个版本,该版本可以接受稀疏矩阵
>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
对 sparsify()
的支持包括大量最常见的原语,包括
广义(批量)矩阵乘积和爱因斯坦求和 (
dot_general_p
)保留零的逐元素二元运算(例如
add_p
、mul_p
等)保留零的逐元素一元运算(例如
abs_p
、jax.lax.neg_p
等)求和归约 (
reduce_sum_p
)通用索引操作 (
slice_p
、lax.dynamic_slice_p、lax.gather_p)连接和堆叠 (
concatenate_p
)转置和重塑 ((
transpose_p
、reshape_p
、squeeze_p
、broadcast_in_dim_p
)一些高阶函数 (
cond_p
、while_p
、scan_p
)一些简单的一维卷积 (
conv_general_dilated_p
)
几乎任何降低到这些支持的原语的 jax.numpy
函数都可以在 sparsify 转换中使用,以对稀疏数组进行操作。这组原语足以支持相对复杂的稀疏工作流程,如下一节所示。
示例:稀疏逻辑回归#
作为更复杂的稀疏工作流程的示例,让我们考虑一个在 JAX 中实现的简单逻辑回归。请注意,以下实现没有对稀疏性的引用
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
... return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
... return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
... y_hat = y_model(params, X)
... return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
... params = jnp.zeros(X.shape[1] + 1)
... result = optimize.minimize(functools.partial(loss, X=X, y=y),
... x0=params, method='BFGS')
... return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)
[-0.7298445 0.29893667 1.0248291 -0.44436368 0.8785025 -0.7724008
-0.62893456 0.2934014 0.82974285 0.16838408 -0.39774987 -0.5071844
0.2028872 0.5227761 -0.3739224 -0.7104083 2.4212713 0.6310087
-0.67060554 0.03139788 -0.05359547]
这将返回密集逻辑回归问题的最佳拟合参数。为了在稀疏数据上拟合相同的模型,我们可以应用 sparsify()
转换
>>> Xsp = sparse.BCOO.fromdense(X) # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg) # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)
[-0.72971725 0.29878938 1.0246326 -0.44430563 0.8784217 -0.77225566
-0.6288222 0.29335397 0.8293481 0.16820715 -0.39764675 -0.5069753
0.202579 0.522672 -0.3740134 -0.7102678 2.4209507 0.6310593
-0.670236 0.03132951 -0.05356663]
稀疏 API 参考#
|
实验性稀疏化转换。 |
|
|
|
|
|
创建一个空的稀疏数组。 |
|
创建 2D 稀疏单位矩阵。 |
|
将输入转换为密集矩阵。 |
|
生成一个随机的 BCOO 矩阵。 |
|
高级 JAX 稀疏对象的基础类。 |
BCOO 数据结构#
BCOO
是批处理 COO 格式,是 jax.experimental.sparse
中实现的主要稀疏数据结构。它的操作与 JAX 的核心转换兼容,包括批处理(例如 jax.vmap()
)和自动微分(例如 jax.grad()
)。
|
在 JAX 中实现的实验性批处理 COO 矩阵 |
|
通过复制数据来扩展 BCOO 数组的大小和秩。 |
|
|
|
一种通用的收缩操作。 |
|
在给定稀疏索引处计算输出的收缩操作。 |
|
{func}`jax.lax.dynamic_slice` 的稀疏实现。 |
|
根据稀疏数组的索引从稠密数组中提取值。 |
|
从稠密矩阵创建 BCOO 格式的稀疏矩阵。 |
|
lax.gather 的 BCOO 版本。 |
|
稀疏数组和稠密数组之间的元素级乘法。 |
|
两个稀疏数组的元素级乘法。 |
|
更新 BCOO 矩阵的存储布局(即 n_batch & n_dense)。 |
|
在给定轴上求数组元素的和。 |
|
{func}`jax.lax.reshape` 的稀疏实现。 |
|
{func}`jax.lax.slice` 的稀疏实现。 |
|
对 BCOO 数组的索引进行排序。 |
|
{func}`jax.lax.squeeze` 的稀疏实现。 |
|
对 BCOO 数组中的重复索引求和,返回一个具有排序索引的数组。 |
|
将批处理稀疏矩阵转换为稠密矩阵。 |
|
转置 BCOO 格式的数组。 |
BCSR 数据结构#
BCSR
是批处理压缩稀疏行格式,目前正在开发中。它的操作与 JAX 的核心转换兼容,包括批处理(例如 jax.vmap()
)和自动微分(例如 jax.grad()
)。
|
在 JAX 中实现的实验性批处理 CSR 矩阵。 |
|
一种通用的收缩操作。 |
|
从给定 BCSR(索引,indptr)的稠密矩阵中提取值。 |
|
从稠密矩阵创建 BCSR 格式的稀疏矩阵。 |
|
将批处理稀疏矩阵转换为稠密矩阵。 |
其他稀疏数据结构#
其他稀疏数据结构包括 COO
、CSR
和 CSC
。这些是简单稀疏结构的参考实现,其中实现了一些核心操作。它们的操作通常与自动微分转换(如 jax.grad()
)兼容,但不与批处理转换(如 jax.vmap()
)兼容。
|
在 JAX 中实现的实验性 COO 矩阵。 |
|
在 JAX 中实现的实验性 CSC 矩阵;API 可能会发生变化。 |
|
在 JAX 中实现的实验性 CSR 矩阵。 |
|
从稠密矩阵创建 COO 格式的稀疏矩阵。 |
|
COO 稀疏矩阵和稠密矩阵的乘积。 |
|
COO 稀疏矩阵和稠密向量的乘积。 |
|
将 COO 格式的稀疏矩阵转换为稠密矩阵。 |
|
从稠密矩阵创建 CSR 格式的稀疏矩阵。 |
|
CSR 稀疏矩阵和稠密矩阵的乘积。 |
|
CSR稀疏矩阵与稠密向量的乘积。 |
|
将 CSR 格式的稀疏矩阵转换为稠密矩阵。 |
jax.experimental.sparse.linalg
#
稀疏线性代数例程。
|
使用 QR 分解的稀疏直接求解器。 |
|
使用 LOBPCG 例程计算前 k 个标准特征值。 |