diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 695b7844..e6ea044b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,4 +1,5 @@ import glob +import inspect import logging import os import shutil @@ -7,6 +8,7 @@ import warnings import h5py +import numpy as np import pandas as pd import pytest import torch @@ -761,6 +763,37 @@ def test_test_method_pretrained_model_on_dataset_without_target(self) -> None: assert output.target.unique().tolist()[0] is None assert output.loss.unique().tolist()[0] is None + def test_graph_save_and_load_model(self) -> None: + test_data_graph = "tests/data/hdf5/test.hdf5" + n = 10 + features_transform = { + Nfeat.RESTYPE: {"transform": lambda x: x / 2, "standardize": True}, + Nfeat.BSA: {"transform": None, "standardize": False}, + } + + dataset = GraphDataset( + hdf5_path=test_data_graph, + node_features=[Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA], + target=targets.BINARY, + task=targets.CLASSIF, + features_transform=features_transform, + ) + trainer = Trainer(NaiveNetwork, dataset) + # during the training the model is saved + trainer.train(nepoch=2, batch_size=2, filename=self.save_path) + assert trainer.features_transform == features_transform + + # load the model into a new GraphDataset instance + dataset_test = GraphDataset( + hdf5_path="tests/data/hdf5/test.hdf5", + train_source=self.save_path, + ) + + # Check if the features_transform is correctly loaded from the saved model + assert dataset_test.features_transform[Nfeat.RESTYPE]["transform"](n) == n / 2 # the only way to test the transform in this case is to apply it + assert dataset_test.features_transform[Nfeat.RESTYPE]["standardize"] == features_transform[Nfeat.RESTYPE]["standardize"] + assert dataset_test.features_transform[Nfeat.BSA] == features_transform[Nfeat.BSA] + if __name__ == "__main__": unittest.main()