Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STGNN Implmentation #9859

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

Conversation

TKM6403
Copy link

@TKM6403 TKM6403 commented Dec 14, 2024

This pull request was created as part of the final project for the CS224W class at Stanford. This is a PyG implementation of the forecasting model in “Pre-training Enhanced Spatial-temporal Graph Neural Network for Multivariate Time Series Forecasting”. This jointly trains a Graph Structure Learner (which learns a dependency matrix between multivariate time series variables) with a STGNN (spatio-temporal graph neural network) backend, with some key modifications.

The most significant difference is that we replace their STGNN framework of the forecasting model, which was originally set to GraphWaveNet (not implemented in PyG, and their repo code does not use PyG for this), with a STGCN (spatio-temporal graph convolutional neural network, which we implement much of in PyG). This work is covered more in “Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework, and we implement it in PyG. This consists of the following modules:

Report on Medium

a) TempConv. This is a temporal convolutional layer, described in the above paper, and it does not require PyG to implement. However, it could be a good addition to torch_geometric.nn.conv.

b) SpatioTempConv. This is a spatio-temporal convolutional layer, described in the above paper, and it is built from TempConv layers and Chebyshev convolutional layers (which are already implemented in PyG). This is not currently in PyG, and would be a great addition to torch_geometric.nn.conv, as it is one of the only convolutional layers which captures both spatial and temporal components.

c) EnhancedSTGCN. This is built with SpatioTempConv layers and an MLP, and takes in input time series data, a dependency graph between time series variables, as well as encoded representations (this is why it is called enhanced, as the vanilla STGCN does not use this) of patches of the data. We generated these encoded representations through STEP (spatio-temporal enhanced pretraining), described in paper from the first paragraph. Note STEP is independent of the training of this enhanced STGCN, so the encodings are generated before training. However, we do not include STEP in our PR, since our modularization allows one to plug in encoded representations of patches from any model. To see the implementation of STEP, see code here.

d) GraphStructureLearner. Often, dependency graphs between multivariate time series variables are not available or easily inferred. This module attempts to actually learn the dependency graph as an adjacency matrix, using a kNN on the encoded patch representations (from STEP), as well as global features on the entire time series data.

e) DownstreamModel. This allows us to jointly train the GraphStructureLearner and the EnhancedSTGCN, since the (learnable) adjacency matrix from the former is directly used as input to the latter. This contains a loss contribution from both models, with the graph structure loss decaying over time. We also enable forecasting functionality: this can be used to predict the next L time steps of multivariate time series data from the previous L time steps, and the final learned adjacency matrix. This would be a good addition to torch_geometric.nn.models, being one of the few models that both enables joint modeling of spatial and temporal dependencies, and learning of dependency graphs between multiple variables when it is not already known.

@akihironitta akihironitta changed the title STGNN Implmentation for CS224w Final Project STGNN Implmentation Dec 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants