diff --git a/tests/test_predict.py b/tests/test_predict.py index e2e8f2ba2..2c26a4a3e 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -6,7 +6,9 @@ import pytest +from clinicadl.metrics.utils import get_metrics from clinicadl.predict.predict_manager import PredictManager +from clinicadl.predict.utils import get_prediction from .testing_tools import compare_folders, modify_maps @@ -115,10 +117,18 @@ def test_predict(cmdopt, tmp_path, test_name): predict_manager.predict() for mode in modes: - predict_manager.maps_manager.get_prediction(data_group="test-RANDOM", mode=mode) + get_prediction( + predict_manager.maps_manager.maps_path, + predict_manager.maps_manager.split_name, + data_group="test-RANDOM", + mode=mode, + ) if use_labels: - predict_manager.maps_manager.get_metrics( - data_group="test-RANDOM", mode=mode + get_metrics( + predict_manager.maps_manager.maps_path, + predict_manager.maps_manager.split_name, + data_group="test-RANDOM", + mode=mode, ) assert compare_folders(