Skip to content

Commit

Permalink
Merge pull request #4 from carlthome/add-pytest
Browse files Browse the repository at this point in the history
Add unit test workflow
  • Loading branch information
SebastianLoef authored Sep 1, 2023
2 parents 3b54848 + d6b65ea commit 31eb32b
Show file tree
Hide file tree
Showing 14 changed files with 137 additions and 75 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
on: push

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
cache: pip

- name: Install dependencies
run: make requirements

- name: Install test runner
run: pip install pytest pytest-cov

- name: Run unit tests
run: pytest --cov=src
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[tool.isort]
profile = "black"

[tool.autoflake]
remove_all_unused_imports = true
1 change: 0 additions & 1 deletion src/architectures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.nn as nn
from torchvision.models import ResNet50_Weights, resnet50

Expand Down
14 changes: 14 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from src.data.freemusicarchive import FreeMusicArchive
from src.data.gtzan import GTZAN
from src.data.magnatagatune import MagnaTagATune
from src.data.millionsongdataset import MillionSongDataset
from src.data.nsynth import NSynthInstrument, NSynthPitch

DATASETS = {
"mtat": MagnaTagATune,
"fma": FreeMusicArchive,
"gtzan": GTZAN,
"msd": MillionSongDataset,
"nsynth_instrument": NSynthInstrument,
"nsynth_pitch": NSynthPitch,
}
7 changes: 3 additions & 4 deletions src/data/test_dataset.py → src/data/clips_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from transforms import MelSpectrogram
from src.transforms import MelSpectrogram


class TestDataset(Dataset):
class ClipsDataset(Dataset):
def __init__(self, args, dataset: Dataset) -> None:
super().__init__()
self.dataset = dataset
Expand All @@ -33,7 +32,7 @@ def __len__(self) -> int:
from torch.utils.data import DataLoader

dataset = MagnaTagATune("test")
tdataset = TestDataset(dataset)
tdataset = ClipsDataset(dataset)
loader = DataLoader(tdataset, batch_size=1, shuffle=False)

for batch, label in loader:
Expand Down
6 changes: 3 additions & 3 deletions src/data/encoded_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch import Tensor
from torchvision.transforms import Compose

from transforms import MelSpectrogram, RandomResizedCrop
from utils import generate_encodings, get_dataset
from src.transforms import MelSpectrogram, RandomResizedCrop
from src.utils import generate_encodings


class EncodedDataset(nn.Module):
Expand All @@ -31,7 +31,7 @@ def get_integral_dataset(
MelSpectrogram(backbone_args),
]
)
dataset = get_dataset(args.dataset)(subset=subset, transforms=transforms)
dataset = DATASETS[args.dataset](subset=subset, transforms=transforms)
self.MULTILABEL = dataset.MULTILABEL
self.NUM_LABELS = dataset.NUM_LABELS
return dataset
Expand Down
1 change: 0 additions & 1 deletion src/data/nsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import shutil
from typing import Tuple

import pandas as pd
import requests
import torch
import torchaudio
Expand Down
24 changes: 9 additions & 15 deletions src/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import argparse

import lightning as L
import numpy as np
import yaml
from sklearn import metrics
from sklearn.linear_model import LinearRegression, LogisticRegression, SGDRegressor
from sklearn.linear_model import SGDRegressor
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain
from sklearn.tree import DecisionTreeRegressor

from architectures import resnet
from modules.VICReg import VICReg
from transforms import MelSpectrogram
from utils import (
from src.architectures import resnet
from src.data import DATASETS
from src.modules.VICReg import VICReg
from src.transforms import MelSpectrogram
from src.utils import (
generate_encodings,
get_best_metric_checkpoint_path,
get_dataset,
get_epoch_checkpoint_path,
load_parameters,
)
Expand Down Expand Up @@ -91,11 +87,9 @@ def main(args):
# datasets
############################
transforms = MelSpectrogram(backbone_args)
train_dataset = get_dataset(args.train_dataset)(
subset="train", transforms=transforms
)
val_dataset = get_dataset(args.val_dataset)(subset="valid", transforms=transforms)
test_dataset = get_dataset(args.test_dataset)(subset="test", transforms=transforms)
train_dataset = DATASETS[args.dataset](subset="train", transforms=transforms)
val_dataset = DATASETS[args.val_dataset](subset="valid", transforms=transforms)
test_dataset = DATASETS[args.test_dataset](subset="test", transforms=transforms)

############################
# model
Expand Down
16 changes: 9 additions & 7 deletions src/modules/VICReg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@
from torch import Tensor
from torch.utils.data import DataLoader

from architectures import mlp
from optimizers import LARS, adjust_learning_rate, include_bias_and_norm
from transforms import AudioSplit
from utils import get_dataset, off_diagonal
from src.architectures import mlp
from src.optimizers import LARS, adjust_learning_rate, include_bias_and_norm
from src.transforms import AudioSplit
from src.utils import off_diagonal


class VICReg(L.LightningModule):
def __init__(self, args, backbone):
def __init__(self, args, dataset, backbone):
super().__init__()
self.args = args
self.num_features = int(args.projector.split("-")[-1])
self.backbone = backbone
self.projector = mlp(args.projector)
self.val_outputs = []
self.train_outputs = []
self.dataset = dataset

def internal_forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
x = self.projector(self.backbone(x))
Expand Down Expand Up @@ -103,9 +104,10 @@ def configure_optimizers(self):
return optimizer

def train_dataloader(self) -> TRAIN_DATALOADERS:
dataset = get_dataset(self.args.dataset)
return DataLoader(
dataset("train", transforms=AudioSplit(self.args), mixing=self.args.mixing),
self.dataset(
"train", transforms=AudioSplit(self.args), mixing=self.args.mixing
),
batch_size=self.args.batch_size,
shuffle=True,
num_workers=self.args.num_workers,
Expand Down
10 changes: 6 additions & 4 deletions src/train_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from architectures import resnet
from modules.VICReg import VICReg
from utils import get_model_name, get_model_number, save_parameters
from src.architectures import resnet
from src.data import DATASETS
from src.modules.VICReg import VICReg
from src.utils import get_model_name, get_model_number, save_parameters


def get_arguments():
Expand Down Expand Up @@ -45,7 +46,8 @@ def main(args):
# model
############################
backbone = resnet(args.pretrained)
model = VICReg(args, backbone)
dataset = DATASETS[args.dataset]
model = VICReg(args, dataset, backbone)
checkpoint = ModelCheckpoint(
dirpath=f"data/models/{name}",
filename="vicreg-{epoch:02d}",
Expand Down
21 changes: 11 additions & 10 deletions src/train_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from torchaudio_augmentations import RandomResizedCrop
from torchvision.transforms import Compose

from architectures import resnet
from data.test_dataset import TestDataset
from modules.Classifier import Classifier
from modules.VICReg import VICReg
from transforms import MelSpectrogram
from utils import (
from src.architectures import resnet
from src.data import DATASETS
from src.data.clips_dataset import ClipsDataset
from src.modules.Classifier import Classifier
from src.modules.VICReg import VICReg
from src.transforms import MelSpectrogram
from src.utils import (
class_balanced_sampler,
get_best_metric_checkpoint_path,
get_dataset,
get_epoch_checkpoint_path,
get_model_number,
load_parameters,
Expand Down Expand Up @@ -65,11 +65,11 @@ def main(args):
############################
# dataset
############################
dataset = get_dataset(args.dataset)
dataset = DATASETS[backbone_args.dataset]
train_dataset = dataset(subset="train", transforms=transforms)
val_dataset = dataset(subset="valid", transforms=transforms)
test_dataset = dataset(subset="test", transforms=None)
test_dataset = TestDataset(backbone_args, test_dataset)
test_dataset = ClipsDataset(backbone_args, test_dataset)
if args.class_balanced:
sampler = class_balanced_sampler(train_dataset)
shuffle = False
Expand Down Expand Up @@ -105,8 +105,9 @@ def main(args):
# model
############################
print(backbone_path)
dataset = DATASETS[backbone_args.dataset]
backbone_module = VICReg.load_from_checkpoint(
backbone_path, args=backbone_args, backbone=resnet()
backbone_path, args=backbone_args, dataset=dataset, backbone=resnet()
)
backbone = backbone_module.backbone.cpu()
model = Classifier(args, MULTILABELS, NUM_LABELS, backbone)
Expand Down
30 changes: 0 additions & 30 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,6 @@
import torch
from tqdm import tqdm

from data.freemusicarchive import FreeMusicArchive
from data.gtzan import GTZAN
from data.magnatagatune import MagnaTagATune
from data.millionsongdataset import MillionSongDataset
from data.nsynth import NSynthInstrument, NSynthPitch

# import wget


def generate_encodings(args, module, dataset, subset, normalize=False):
path = f"data/models/{args.name}/{args.dataset}"
Expand Down Expand Up @@ -82,28 +74,6 @@ def get_epoch_checkpoint_path(name: str, epoch: int = 0) -> str:
return d[idx]


def get_dataset(name: str):
if name == "mtat":
print("Using MagnaTagATune dataset")
return MagnaTagATune
elif name == "fma":
print("Using FreeMusicArchive dataset")
return FreeMusicArchive
elif name == "gtzan":
print("Using GTZAN dataset")
return GTZAN
elif name == "msd":
print("Using MillionSongDataset dataset")
return MillionSongDataset
elif "nsynth" in name:
if "instrument" in name:
return NSynthInstrument
elif "pitch" in name:
return NSynthPitch

raise NotImplementedError


def save_parameters(args, name):
if not os.path.exists(f"data/models/{name}"):
os.makedirs(f"data/models/{name}")
Expand Down
Empty file added tests/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import argparse

import lightning
import torch
from torch.utils.data import Dataset

from src.architectures import resnet
from src.modules.VICReg import VICReg


class RandomDataset(Dataset):
def __init__(self, subset, mixing, transforms) -> None:
self.transforms = transforms

def __getitem__(self, index: int):
x = torch.rand(3, 128, 128, requires_grad=True)
y = torch.rand(3, 128, 128, requires_grad=True)
z = torch.rand(1, requires_grad=True)
return (x, y), z

def __len__(self) -> int:
return 10


def test_train():
args = argparse.Namespace(
batch_size=2,
cov_coeff=1.0,
devices=1,
epochs=1,
f_max=1.0,
f_min=1.0,
hop_length=1,
mixing=False,
n_fft=1,
n_samples=2,
normalize=False,
num_workers=0,
prefetch_factor=None,
projector="2048-2-2048",
sample_rate=3,
sim_coeff=1.0,
std_coeff=1.0,
strategy="auto",
weight_decay=1e-9,
win_length=1,
base_lr=1e-3,
)

backbone = resnet(pretrained=False)
dataset = RandomDataset
model = VICReg(args=args, dataset=dataset, backbone=backbone)

trainer = lightning.Trainer(max_epochs=1)
trainer.fit(model)

assert trainer.state.finished

0 comments on commit 31eb32b

Please sign in to comment.