cai4cai/torchsparsegradutils

Sparsity-preserving gradient utility tools for PyTorch

A collection of utility functions to work with PyTorch sparse tensors. This is a work in progress; here be dragons.

Currently available features with backprop include:

  • Memory-efficient sparse mm with batch support (workaround for pytorch/pytorch#41128)
  • Sparse triangular solver with batch support (see discussion in pytorch/pytorch#87358)
  • Generic sparse linear solver (requires a non-differentiable backbone sparse solver)
  • Generic sparse linear least-squares solver (requires a non-differentiable backbone sparse linear least-squares solver)
  • Wrappers around CuPy sparse solvers (see discussion in pytorch/pytorch#69538)
  • Wrappers around JAX sparse solvers
  • Sparse multivariate normal distribution with sparse covariance and precision parameterisation, with reparameterised sampling (rsample)
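To illustrate what "generic sparse linear solver with a non-differentiable backbone" can mean in practice, here is a minimal sketch in plain PyTorch. It is not this package's implementation or API: the class name `SparseSolveSketch` and its argument layout are hypothetical, and a dense `torch.linalg.solve` stands in for the backbone purely to keep the example self-contained. The point is the backward pass, which computes the gradient with respect to the matrix only at its stored nonzero positions instead of forming a dense n × n gradient.

```python
import torch


class SparseSolveSketch(torch.autograd.Function):
    """Hypothetical sketch: solve A x = b with A given in COO form.

    The gradient w.r.t. A is returned only for the stored values,
    so it has length nnz rather than n * n.
    """

    @staticmethod
    def forward(ctx, indices, values, n, b):
        A = torch.sparse_coo_tensor(indices, values, (n, n))
        # Stand-in backbone (assumption): any non-differentiable sparse
        # solver could be called here instead of this dense solve.
        x = torch.linalg.solve(A.to_dense(), b)
        ctx.save_for_backward(indices, values, x)
        ctx.n = n
        return x

    @staticmethod
    def backward(ctx, grad_x):
        indices, values, x = ctx.saved_tensors
        A = torch.sparse_coo_tensor(indices, values, (ctx.n, ctx.n))
        # grad_b solves the transposed system: A^T grad_b = grad_x
        grad_b = torch.linalg.solve(A.to_dense().T, grad_x)
        # dL/dA = -grad_b x^T, evaluated only at the nonzero (row, col) pairs
        rows, cols = indices
        grad_values = -(grad_b[rows] * x[cols]).sum(dim=-1)
        # No gradients for the integer indices or the size argument
        return None, grad_values, None, grad_b


# Small upper-triangular system with 6 stored entries
n = 4
indices = torch.tensor([[0, 1, 2, 3, 0, 2],
                        [0, 1, 2, 3, 1, 3]])
values = torch.tensor([4.0, 5.0, 6.0, 7.0, 1.0, 2.0],
                      dtype=torch.float64, requires_grad=True)
torch.manual_seed(0)
b = torch.randn(n, 2, dtype=torch.float64, requires_grad=True)

x = SparseSolveSketch.apply(indices, values, n, b)
x.sum().backward()
print(values.grad.shape)  # one gradient entry per stored nonzero
```

Passing the indices and values separately, rather than a sparse tensor, is a simplification made for this sketch so that the gradient is an ordinary dense vector of length nnz.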

Additional backbone solvers implemented in pytorch with no additional dependencies include:

Additional features:

  • Pairwise voxel encoder for encoding local neighbourhood relationships in a 3D spatial volume with multiple channels, into a sparse COO or CSR matrix
  • Pure PyTorch implementations of indexed multiplication operations (segment_mm and gather_mm - as provided by dgl.ops.segment_mm, pyg_lib.ops.segment_matmul, and dgl.ops.gather_mm)
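For readers unfamiliar with the indexed multiplication operations mentioned above, the following sketch shows their semantics in plain PyTorch. The function names `segment_mm_sketch` and `gather_mm_sketch` are hypothetical and the code is not taken from this package; it assumes the usual convention in which `a` has shape (N, D1), `b` has shape (R, D1, D2), and either segment lengths or a per-row index select which matrix in `b` each row of `a` is multiplied by.

```python
import torch


def segment_mm_sketch(a, b, seglen_a):
    """Multiply consecutive row segments of `a` by the matching matrix in `b`.

    a: (N, D1), b: (R, D1, D2), seglen_a: (R,) with seglen_a.sum() == N.
    Returns a tensor of shape (N, D2).
    """
    segments = torch.split(a, seglen_a.tolist())
    return torch.cat([seg @ b[i] for i, seg in enumerate(segments)], dim=0)


def gather_mm_sketch(a, b, idx_b):
    """Multiply each row a[i] by the matrix b[idx_b[i]].

    a: (N, D1), b: (R, D1, D2), idx_b: (N,) integer indices into b.
    Returns a tensor of shape (N, D2).
    """
    # b[idx_b] has shape (N, D1, D2); bmm treats each row as a 1 x D1 matrix
    return torch.bmm(a.unsqueeze(1), b[idx_b]).squeeze(1)


torch.manual_seed(0)
a = torch.randn(5, 3)
b = torch.randn(2, 3, 4)

# First 2 rows use b[0], remaining 3 rows use b[1]
seg_out = segment_mm_sketch(a, b, torch.tensor([2, 3]))
gat_out = gather_mm_sketch(a, b, torch.tensor([0, 0, 1, 1, 1]))
print(seg_out.shape, torch.allclose(seg_out, gat_out, atol=1e-6))
```

Note that `gather_mm_sketch` materialises one (D1, D2) matrix per row of `a`, which is simple but memory-hungry; it is meant to show the semantics, not an efficient implementation.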

Features that are still missing may be listed in the repository's issues.

Installation

The package can be installed from PyPI using:

pip install torchsparsegradutils

or directly from the GitHub repository using:

pip install git+https://github.com/cai4cai/torchsparsegradutils

Unit Tests

A number of unit tests are provided, which can be run with:

python -m pytest

(Note that pytest also collects and runs the tests written with unittest.)