diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 000000000..51b45fd41 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,67 @@ +name: Build and Test + +on: + pull_request: + branches: + - main + push: + branches: + - main + schedule: + # Run the tests at 00:00 each day + - cron: "0 0 * * *" + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + defaults: + run: + shell: bash -l {0} + + steps: + - uses: actions/checkout@v2 + - name: cache conda + uses: actions/cache@v2 + env: + # Increase this value to reset cache if etc/example-environment.yml has not changed + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: + ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ + hashFiles('requirements.txt') }} + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: test + python-version: 3.8 + use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! + - name: Set up env + run: | + conda activate test + conda install pip + - name: Cache pip + uses: actions/cache@v2 + with: + # This path is specific to Ubuntu + path: ~/.cache/pip + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + ${{ runner.os }}- + - name: pytorch + run: | + conda install -y pytorch=1.7.1 torchvision cudatoolkit=10.2 -c pytorch --update-deps + - name: Install pyG + run: | + ./pyG_install.sh cu102 + - name: Install dependencies + run: | + pip install -r requirements.txt + - name: Test with pytest + run: | + pytest -v \ No newline at end of file diff --git a/README.md b/README.md index a3c03db5f..0bcf6b53e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +![example workflow](https://github.com/twitter-research/graph-neural-pde/actions/workflows/python-package.yml/badge.svg) + ![Cora_animation_16](https://user-images.githubusercontent.com/5874124/143270624-265c2d01-39ca-488c-b118-b68f876dfbfa.gif) ## Introduction diff --git a/pyG_install.sh b/pyG_install.sh new file mode 100755 index 000000000..1dbc15a72 --- /dev/null +++ b/pyG_install.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +TORCH=1.7.1 +CUDA=$1 # Supply as command line cpu or cu102 +pip install torch-scatter==2.0.5 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html +pip install torch-sparse==0.6.8 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html +pip install torch-cluster==1.5.8 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html +pip install torch-spline-conv==1.2.0 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html +pip install torch-geometric==1.6.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ef532627b..2fd4f9534 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,211 +1,31 @@ # This file may be used to create an environment using: # $ conda create --name --file # platform: osx-64 -absl-py=0.9.0=pypi_0 -aiohttp=3.7.2=pypi_0 -aiohttp-cors=0.7.0=pypi_0 -aioredis=1.3.1=pypi_0 -appnope=0.1.0=py38_1001 -argon2-cffi=20.1.0=py38haf1e3a3_1 -ase=3.20.1=pypi_0 -async-timeout=3.0.1=pypi_0 -async_generator=1.10=py_0 -attrs=20.2.0=py_0 -ax-platform=0.1.18=pypi_0 -backcall=0.2.0=py_0 -beautifulsoup4=4.9.3=pypi_0 -blas=1.0=mkl -bleach=3.2.1=py_0 -blessings=1.7=pypi_0 -boltons=20.2.1=pypi_0 -botorch=0.3.2=pypi_0 -ca-certificates=2020.10.14=0 -cachetools=4.1.0=pypi_0 -certifi=2020.6.20=pyhd3eb1b0_3 -cffi=1.14.3=py38hed5b41f_0 -chardet=3.0.4=pypi_0 -click=7.1.2=pypi_0 -colorama=0.4.4=pypi_0 -colorful=0.5.4=pypi_0 -cycler=0.10.0=py38_0 -dbus=1.13.18=h18a8e69_0 -decorator=4.4.2=pypi_0 -defusedxml=0.6.0=py_0 -dgl=0.4.3.post2=pypi_0 -entrypoints=0.3=py38_0 -et-xmlfile=1.0.1=pypi_0 -expat=2.2.10=hb1e8313_2 -filelock=3.0.12=pypi_0 -freetype=2.10.2=ha233b18_0 -future=0.18.2=pypi_0 -gettext=0.19.8.1=hb0f4f8b_2 -glib=2.66.1=h9bbe63b_0 -google=3.0.0=pypi_0 -google-api-core=1.23.0=pypi_0 -google-auth=1.23.0=pypi_0 -google-auth-oauthlib=0.4.1=pypi_0 -googleapis-common-protos=1.52.0=pypi_0 -googledrivedownloader=0.4=pypi_0 -gpustat=0.6.0=pypi_0 -gputil=1.4.0=pypi_0 -gpytorch=1.2.1=pypi_0 -grpcio=1.30.0=pypi_0 -h5py=2.10.0=pypi_0 -hiredis=1.1.0=pypi_0 -icu=58.2=h0a44026_3 -idna=2.10=pypi_0 -importlib-metadata=2.0.0=py_1 -importlib_metadata=2.0.0=1 -iniconfig=1.1.1=py_0 -intel-openmp=2019.4=233 -ipykernel=5.3.4=py38h5ca1d4c_0 -ipython=7.18.1=py38h5ca1d4c_0 -ipython_genutils=0.2.0=py38_0 -ipywidgets=7.5.1=py_1 -isodate=0.6.0=pypi_0 -jdcal=1.4.1=pypi_0 -jedi=0.17.2=py38_0 -jinja2=2.11.2=py_0 -joblib=0.15.1=py_0 -jpeg=9b=he5867d9_2 -jsonschema=3.2.0=py_2 -jupyter=1.0.0=py38_7 -jupyter_client=6.1.7=py_0 -jupyter_console=6.2.0=py_0 -jupyter_core=4.6.3=py38_0 -jupyterlab_pygments=0.1.2=py_0 -kiwisolver=1.2.0=py38h04f5b5a_0 -libcxx=10.0.0=1 -libedit=3.1.20191231=haf1e3a3_0 -libffi=3.3=h0a44026_1 -libgfortran=3.0.1=h93005f0_2 -libiconv=1.16=h1de35cc_0 -libpng=1.6.37=ha441bb4_0 -libprotobuf=3.13.0.1=hab81aa3_0 -libsodium=1.0.18=h1de35cc_0 -libtiff=4.1.0=hcb84e12_1 -littleutils=0.2.2=pypi_0 -llvm-openmp=10.0.0=h28b9765_0 -llvmlite=0.34.0=pypi_0 -lz4-c=1.9.2=h0a44026_0 -markdown=3.2.2=pypi_0 -markupsafe=1.1.1=py38h1de35cc_1 -matplotlib=3.2.2=0 -matplotlib-base=3.2.2=py38h5670ca0_0 -mistune=0.8.4=py38h1de35cc_1001 -mkl=2019.4=233 -mkl-service=2.3.0=py38hfbe908c_0 -mkl_fft=1.1.0=py38hc64f4ea_0 -mkl_random=1.1.1=py38h959d312_0 -more-itertools=8.5.0=py_0 -msgpack=1.0.0=pypi_0 -multidict=5.0.0=pypi_0 -nbclient=0.5.1=py_0 -nbconvert=6.0.7=py38_0 -nbformat=5.0.8=py_0 -ncurses=6.2=h0a44026_1 -nest-asyncio=1.4.1=py_0 -networkx=2.4=pypi_0 -ninja=1.9.0=py38h04f5b5a_0 -notebook=6.1.4=py38_0 -numba=0.51.2=pypi_0 -numpy=1.19.3=pypi_0 -numpy-base=1.18.5=py38h3304bdc_0 -nvidia-ml-py3=7.352.0=pypi_0 -oauthlib=3.1.0=pypi_0 -ogb=1.2.4=pypi_0 -olefile=0.46=py_0 -opencensus=0.7.11=pypi_0 -opencensus-context=0.1.2=pypi_0 -openpyxl=3.0.5=pypi_0 -openssl=1.1.1h=haf1e3a3_0 -outdated=0.2.0=pypi_0 -packaging=20.4=py_0 -pandas=1.0.5=py38h959d312_0 -pandoc=2.11=h0dc7051_0 -pandocfilters=1.4.2=py38_1 -parso=0.7.0=py_0 -pcre=8.44=hb1e8313_0 -pexpect=4.8.0=py38_1 -pickleshare=0.7.5=py38_1001 -pillow=7.1.2=py38h4655f20_0 -pip=20.1.1=py38_1 -plotly=4.12.0=pypi_0 -pluggy=0.13.1=py38_0 -prometheus_client=0.8.0=py_0 -prompt-toolkit=3.0.8=py_0 -prompt_toolkit=3.0.8=0 -protobuf=3.12.2=pypi_0 -psutil=5.7.3=pypi_0 -ptyprocess=0.6.0=py38_0 -py=1.9.0=py_0 -py-spy=0.3.3=pypi_0 -pyasn1=0.4.8=pypi_0 -pyasn1-modules=0.2.8=pypi_0 -pycparser=2.20=py_2 -pygments=2.7.2=pyhd3eb1b0_0 -pykeops=1.4.1=pypi_0 -pyparsing=2.4.7=py_0 -pyqt=5.9.2=py38h655552a_2 -pyrsistent=0.17.3=py38haf1e3a3_0 -pytest=6.1.1=py38_0 -python=3.8.3=h26836e1_1 -python-dateutil=2.8.1=py_0 -pytorch-lightning=0.8.1=pypi_0 -pytz=2020.1=py_0 -pyzmq=19.0.2=py38hb1e8313_1 -qt=5.9.7=h468cd18_1 -qtconsole=4.7.7=py_0 -qtpy=1.9.0=py_0 -ray=1.0.0=pypi_0 -rdflib=5.0.0=pypi_0 -readline=8.0=h1de35cc_0 -redis=3.4.1=pypi_0 -requests=2.24.0=pypi_0 -requests-oauthlib=1.3.0=pypi_0 -retrying=1.3.3=pypi_0 -rsa=4.6=pypi_0 -scikit-learn=0.23.1=py38h603561c_0 -scipy=1.5.3=pypi_0 -send2trash=1.5.0=py38_0 -setuptools=47.3.1=py38_0 -sip=4.19.8=py38h0a44026_0 -six=1.15.0=py_0 -soupsieve=2.0.1=pypi_0 -sqlalchemy=1.3.20=pypi_0 -sqlite=3.32.3=hffcf06c_0 -tabulate=0.8.7=pypi_0 -tensorboard=2.2.2=pypi_0 -tensorboard-plugin-wit=1.7.0=pypi_0 -tensorboardx=2.1=pypi_0 -terminado=0.9.1=py38_0 -testpath=0.4.4=py_0 -threadpoolctl=2.1.0=pyh5ca1d4c_0 -tk=8.6.10=hb0a8c7a_0 -toml=0.10.1=py_0 -torch=1.6.0=pypi_0 -torch-geometric=1.6.1=pypi_0 -torch-scatter=2.0.5=pypi_0 -torch-sparse=0.6.7=pypi_0 -torchdiffeq=0.1.1=pypi_0 -torchsde=0.2.1=pypi_0 -torchsummary=1.5.1=pypi_0 -torchvision=0.7.0=pypi_0 -tornado=6.0.4=py38h1de35cc_1 -tqdm=4.46.1=pypi_0 -traitlets=5.0.5=py_0 -trampoline=0.1.2=pypi_0 -typeguard=2.10.0=pypi_0 -typing-extensions=3.7.4.3=pypi_0 -urllib3=1.25.9=pypi_0 -wcwidth=0.2.5=py_0 -webencodings=0.5.1=py38_1 -werkzeug=1.0.1=pypi_0 -wheel=0.34.2=py38_0 -widgetsnbextension=3.5.1=py38_0 -xz=5.2.5=h1de35cc_0 -yarl=1.6.2=pypi_0 -zeromq=4.3.3=hb1e8313_3 -zipp=3.4.0=pyhd3eb1b0_0 -zlib=1.2.11=h1de35cc_3 -zstd=1.4.4=h1990bb4_3 +h5py==2.10.0 +jupyter==1.0.0 +matplotlib==3.2.2 +networkx==2.4 +numba==0.51.2 +numpy==1.19.3 +ogb==1.2.4 +pandas==1.0.5 +pip==20.1.1 +pykeops==1.4.1 +pytest==6.1.1 +ray==1.0.0 +scikit-learn==0.23.1 +scipy==1.5.3 +six==1.15.0 +tqdm==4.46.1 +torchdiffeq==0.1.1 +tabulate==0.8.7 + +# reliant on cuda version +# torch==1.6.0=pypi_0 +# torch-geometric=1.6.1=pypi_0 +# torch-scatter=2.0.5=pypi_0 +# torch-sparse=0.6.7=pypi_0 +# torchdiffeq=0.1.1=pypi_0 +# torchsde=0.2.1=pypi_0 +# torchsummary=1.5.1=pypi_0 +# torchvision=0.7.0=pypi_0 \ No newline at end of file diff --git a/test/test_ICML_gnn.py b/test/test_ICML_gnn.py index ab8ea5400..2c8e76f83 100644 --- a/test/test_ICML_gnn.py +++ b/test/test_ICML_gnn.py @@ -3,17 +3,21 @@ """ Test attention """ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) + import unittest import torch from torch import tensor from torch import nn -from CGNN import gcn_norm_fill_val, coo2tensor, train_ray -from data import get_dataset from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.utils.convert import to_scipy_sparse_matrix from ray.tune.utils import diagnose_serialization from functools import partial -import os + +from CGNN import gcn_norm_fill_val, coo2tensor, train_ray +from data import get_dataset from test_params import OPT diff --git a/test/test_attention.py b/test/test_attention.py index 7284090a8..f5a246b38 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -3,10 +3,16 @@ """ Test attention """ +# needed for CI/CD +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) + import unittest import torch from torch import tensor from torch import nn + from function_GAT_attention import SpGraphAttentionLayer, ODEFuncAtt from torch_geometric.utils import softmax, to_dense_adj from data import get_dataset diff --git a/test/test_attention_ode_block.py b/test/test_attention_ode_block.py index 446bd0b92..6bf62b866 100644 --- a/test/test_attention_ode_block.py +++ b/test/test_attention_ode_block.py @@ -1,10 +1,15 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +# needed for CI/CD +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) import unittest import torch from torch import tensor from torch import nn + from data import get_dataset from function_laplacian_diffusion import LaplacianODEFunc from GNN import GNN diff --git a/test/test_block_mixed.py b/test/test_block_mixed.py index 2b952efe1..5ce27bb3e 100644 --- a/test/test_block_mixed.py +++ b/test/test_block_mixed.py @@ -1,17 +1,22 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +# needed for CI/CD +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) import unittest import torch from torch import tensor from torch import nn +import numpy as np + from data import get_dataset from function_laplacian_diffusion import LaplacianODEFunc from GNN import GNN from block_mixed import MixedODEblock from torch_geometric.data import Data from torch_geometric.utils import to_dense_adj -import numpy as np from test_params import OPT diff --git a/test/test_early_stop.py b/test/test_early_stop.py index 00c6ccbf4..f266a3aff 100644 --- a/test/test_early_stop.py +++ b/test/test_early_stop.py @@ -3,10 +3,15 @@ """ Test early stop """ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) + import unittest import torch from torch import tensor from torch import nn + from data import get_dataset from function_laplacian_diffusion import LaplacianODEFunc from GNN_early import GNNEarly diff --git a/test/test_function_laplacian_diffusion.py b/test/test_function_laplacian_diffusion.py index 1b3edf6bd..b6a1d77b6 100644 --- a/test/test_function_laplacian_diffusion.py +++ b/test/test_function_laplacian_diffusion.py @@ -1,18 +1,22 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) import unittest import torch from torch import tensor from torch import nn +import numpy as np +from sklearn.preprocessing import normalize + from data import get_dataset from function_laplacian_diffusion import LaplacianODEFunc from GNN import GNN from block_constant import ConstantODEblock from torch_geometric.data import Data from torch_geometric.utils import to_dense_adj -import numpy as np -from sklearn.preprocessing import normalize from utils import get_rw_adj, get_sym_adj from test_params import OPT diff --git a/test/test_gnn.py b/test/test_gnn.py index 81ae14829..b20e34b4d 100644 --- a/test/test_gnn.py +++ b/test/test_gnn.py @@ -3,10 +3,15 @@ """ Test the GNN class """ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) + import unittest import torch from torch import tensor from torch import nn + from data import get_dataset from function_laplacian_diffusion import LaplacianODEFunc from block_constant import ConstantODEblock diff --git a/test/test_transformer_attention.py b/test/test_transformer_attention.py index c2909258b..5770e07a3 100644 --- a/test/test_transformer_attention.py +++ b/test/test_transformer_attention.py @@ -3,13 +3,18 @@ """ Test attention """ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) + import unittest import torch from torch import tensor from torch import nn import torch_sparse -from function_transformer_attention import SpGraphTransAttentionLayer, ODEFuncTransformerAtt from torch_geometric.utils import softmax, to_dense_adj + +from function_transformer_attention import SpGraphTransAttentionLayer, ODEFuncTransformerAtt from data import get_dataset from test_params import OPT from utils import ROOT_DIR diff --git a/test/test_utils.py b/test/test_utils.py index 84885072d..88d114b68 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +# needed for CI/CD +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) import unittest import torch