A collection of utility functions for working with PyTorch sparse tensors. This is a work in progress: here be dragons.
Currently available features with backpropagation support include:
- Memory-efficient sparse mm with batch support (workaround for pytorch/pytorch#41128)
- Sparse triangular solver with batch support (see discussion in pytorch/pytorch#87358)
- Generic sparse linear solver (requires a non-differentiable backbone sparse solver)
- Generic sparse linear least-squares solver (requires a non-differentiable backbone sparse linear least-squares solver)
- Wrappers around cupy sparse solvers (see discussion in pytorch/pytorch#69538)
- Wrappers around jax sparse solvers
- Sparse multivariate normal distribution with sparse covariance and precision parameterisation, with reparameterised sampling (rsample)
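The following is a minimal sketch of one way to make a sparse-dense product memory efficient. It assumes the saving comes from restricting the gradient with respect to the sparse operand to that operand's existing nonzeros, so that a dense n x m gradient is never materialised. `SparseDenseMM` is a hypothetical name built only from core PyTorch calls. It is not this package's API, and it handles a single unbatched COO matrix rather than the batched case the package supports.

```python
import torch


class SparseDenseMM(torch.autograd.Function):
    """Hypothetical sketch: C = A @ B for sparse COO A (n x m) and dense B (m x k).

    The gradient w.r.t. A is returned as a sparse tensor with A's sparsity
    pattern instead of a dense n x m matrix. Unbatched only.
    """

    @staticmethod
    def forward(ctx, A, B):
        A = A.coalesce()  # indices() below requires a coalesced tensor
        ctx.save_for_backward(A, B)
        return torch.sparse.mm(A, B)

    @staticmethod
    def backward(ctx, grad_C):
        A, B = ctx.saved_tensors
        grad_A = grad_B = None
        if ctx.needs_input_grad[0]:
            rows, cols = A.indices()
            # dL/dA = grad_C @ B^T, evaluated only at the nonzeros of A:
            # entry (i, j) is the dot product of grad_C[i] and B[j].
            vals = (grad_C[rows] * B[cols]).sum(dim=-1)
            grad_A = torch.sparse_coo_tensor(A.indices(), vals, A.shape)
        if ctx.needs_input_grad[1]:
            # dL/dB = A^T @ grad_C, computed as another sparse-dense product.
            grad_B = torch.sparse.mm(A.t(), grad_C)
        return grad_A, grad_B
```

Each nonzero receives one dot product of length k, so the backward pass needs O(nnz * k) temporary memory for the sparse operand's gradient, rather than the O(n * m) a dense gradient would need.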
Additional backbone solvers implemented in PyTorch with no additional dependencies include:
- BICGSTAB (ported from pykrylov)
- CG (ported from cornellius-gp/linear_operator)
- LSMR (ported from pytorch-minimize)
- MINRES (ported from cornellius-gp/linear_operator)
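A common way to obtain gradients from a non-differentiable backbone solver is implicit differentiation. For x = A^-1 b, the backward pass solves A^T grad_b = grad_x and sets grad_A = -grad_b x^T. The sketch below assumes this approach, with a from-scratch conjugate gradient routine as the backbone. It also assumes A is a symmetric positive-definite sparse COO matrix and b is a single right-hand side. `cg` and `SparseSPDSolve` are hypothetical names, not the ported CG or the package's own solver wrapper.

```python
import torch


@torch.no_grad()  # non-differentiable backbone: autograd never sees its iterations
def cg(A, b, rtol=1e-10, max_iter=None):
    """Conjugate gradient for symmetric positive-definite A (dense or sparse)."""

    def matvec(v):
        # Treat v as a column so sparse @ dense-matrix multiplication is used.
        return (A @ v.unsqueeze(-1)).squeeze(-1)

    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the initial guess x = 0
    p = r.clone()
    rs_old = torch.dot(r, r)
    tol = rtol * torch.linalg.vector_norm(b)
    for _ in range(10 * b.numel() if max_iter is None else max_iter):
        if torch.sqrt(rs_old) <= tol:
            break
        Ap = matvec(p)
        alpha = rs_old / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = torch.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


class SparseSPDSolve(torch.autograd.Function):
    """Hypothetical wrapper: x = A^-1 b, differentiated implicitly, not through CG."""

    @staticmethod
    def forward(ctx, A, b):
        A = A.coalesce()
        x = cg(A, b)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        A, x = ctx.saved_tensors
        # A is symmetric, so solving with A^T is one more solve with A.
        grad_b = cg(A, grad_x)
        grad_A = None
        if ctx.needs_input_grad[0]:
            rows, cols = A.indices()
            # dL/dA = -grad_b x^T, kept only on the sparsity pattern of A.
            vals = -grad_b[rows] * x[cols]
            grad_A = torch.sparse_coo_tensor(A.indices(), vals, A.shape)
        return grad_A, grad_b
```

The backward pass costs one extra solve and never unrolls the CG iterations. In exchange, the backward solve has to reach a tolerance comparable to the forward one for the gradients to be accurate.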
Additional features:
- Pairwise voxel encoder for encoding local neighbourhood relationships in a 3D spatial volume with multiple channels into a sparse COO or CSR matrix
- Pure PyTorch implementations of indexed multiplication operations (segment_mm and gather_mm, as provided by dgl.ops.segment_mm, pyg_lib.ops.segment_matmul, and dgl.ops.gather_mm)
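The indexed multiplications listed above can be expressed with ordinary tensor operations. The following is a hedged reference sketch that assumes DGL-style semantics: `a` has shape (N, D1) and `b` is a stack of weight matrices with shape (R, D1, D2). segment_mm takes segment lengths that sum to N, and gather_mm takes one weight index per row. The helper names mirror the operations but are hypothetical, not this package's functions.

```python
import torch


def segment_mm(a, b, seglen_a):
    """Hypothetical sketch: multiply consecutive row segments of `a` by b[r].

    a: (N, D1), b: (R, D1, D2), seglen_a: (R,) integer lengths summing to N.
    """
    chunks = torch.split(a, seglen_a.tolist(), dim=0)  # one chunk per segment
    return torch.cat([chunk @ w for chunk, w in zip(chunks, b)], dim=0)


def gather_mm(a, b, idx_b):
    """Hypothetical sketch: out[i] = a[i] @ b[idx_b[i]].

    a: (N, D1), b: (R, D1, D2), idx_b: (N,) integer indices into b.
    """
    # b[idx_b] gathers one (D1, D2) matrix per row, giving (N, D1, D2).
    return torch.bmm(a.unsqueeze(1), b[idx_b]).squeeze(1)


# Usage: segment_mm is gather_mm with a run-length-expanded index.
a = torch.randn(6, 3)
b = torch.randn(2, 3, 4)
seglen = torch.tensor([2, 4])
idx = torch.repeat_interleave(torch.arange(2), seglen)  # [0, 0, 1, 1, 1, 1]
out = segment_mm(a, b, seglen)
```

As written, gather_mm materialises b[idx_b], an (N, D1, D2) tensor. This is simple but memory hungry for large N. A loop over distinct indices avoids that copy at the cost of Python overhead. Both functions are differentiable through standard autograd.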
Missing features may be listed as GitHub issues.
The provided package can be installed using:
pip install torchsparsegradutils
or
pip install git+https://github.com/cai4cai/torchsparsegradutils
A number of unit tests are provided, which can be run with:
python -m pytest
(Note that this also runs the tests written with unittest.)