jax.experimental.sparse.sparsify

jax.experimental.sparse.sparsify(f, use_tracer=False)

Experimental sparsification transform.

Examples

Decorate a JAX function to make it compatible with jax.experimental.sparse.BCOO matrices:

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> @sparse.sparsify
... def f(M, v):
...   return 2 * M.T @ v
>>> M = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
>>> v = jnp.array([3, 4, 2])
>>> f(M, v)
Array([ 64,  82, 100, 118], dtype=int32)
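
The transform can also be applied as an ordinary function call rather than as a decorator. The following is a minimal sketch, not part of the original documentation, that wraps the same function and checks the sparse result against a plain dense computation. The names `dense_M`, `f_sparse`, `sparse_result`, and `dense_result` are illustrative only.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def f(M, v):
    # Same computation as in the example above, written without the decorator.
    return 2 * M.T @ v


# Dense reference matrix and its BCOO (sparse) counterpart.
dense_M = jnp.arange(12).reshape(3, 4)
M = sparse.BCOO.fromdense(dense_M)
v = jnp.array([3, 4, 2])

# Wrap f explicitly; f_sparse accepts BCOO inputs.
f_sparse = sparse.sparsify(f)

sparse_result = f_sparse(M, v)   # computed from the BCOO matrix
dense_result = f(dense_M, v)     # computed from the dense matrix

# Both paths should produce the same values.
assert jnp.array_equal(sparse_result, dense_result)
```

Keeping an unwrapped `f` around, as here, makes it easy to compare the sparse path against a dense baseline while testing.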