Skip to content

Commit

Permalink
add relevant test
Browse files Browse the repository at this point in the history
  • Loading branch information
gcroci2 committed Feb 21, 2024
1 parent 2c7f8eb commit 147ae78
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import glob
import inspect
import logging
import os
import shutil
Expand All @@ -7,6 +8,7 @@
import warnings

import h5py
import numpy as np
import pandas as pd
import pytest
import torch
Expand Down Expand Up @@ -761,6 +763,37 @@ def test_test_method_pretrained_model_on_dataset_without_target(self) -> None:
assert output.target.unique().tolist()[0] is None
assert output.loss.unique().tolist()[0] is None

def test_graph_save_and_load_model(self) -> None:
test_data_graph = "tests/data/hdf5/test.hdf5"
n = 10
features_transform = {
Nfeat.RESTYPE: {"transform": lambda x: x / 2, "standardize": True},
Nfeat.BSA: {"transform": None, "standardize": False},
}

dataset = GraphDataset(
hdf5_path=test_data_graph,
node_features=[Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA],
target=targets.BINARY,
task=targets.CLASSIF,
features_transform=features_transform,
)
trainer = Trainer(NaiveNetwork, dataset)
# during the training the model is saved
trainer.train(nepoch=2, batch_size=2, filename=self.save_path)
assert trainer.features_transform == features_transform

# load the model into a new GraphDataset instance
dataset_test = GraphDataset(
hdf5_path="tests/data/hdf5/test.hdf5",
train_source=self.save_path,
)

# Check if the features_transform is correctly loaded from the saved model
assert dataset_test.features_transform[Nfeat.RESTYPE]["transform"](n) == n / 2 # the only way to test the transform in this case is to apply it
assert dataset_test.features_transform[Nfeat.RESTYPE]["standardize"] == features_transform[Nfeat.RESTYPE]["standardize"]
assert dataset_test.features_transform[Nfeat.BSA] == features_transform[Nfeat.BSA]


if __name__ == "__main__":
unittest.main()

0 comments on commit 147ae78

Please sign in to comment.