diff --git a/molpipeline/post_prediction.py b/molpipeline/post_prediction.py index f5a5881e..e0525629 100644 --- a/molpipeline/post_prediction.py +++ b/molpipeline/post_prediction.py @@ -11,7 +11,7 @@ from typing_extensions import Self from numpy import typing as npt -from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.base import BaseEstimator, TransformerMixin from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement from molpipeline.error_handling import FilterReinserter @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: dict[str, Any] Parameters. """ + param_dict = {"wrapped_estimator": self.wrapped_estimator} if deep: - param_dict = { - "wrapped_estimator": clone(self.wrapped_estimator), - } - else: - param_dict = { - "wrapped_estimator": self.wrapped_estimator, - } - param_dict.update(self.wrapped_estimator.get_params(deep=deep)) + for key, value in self.wrapped_estimator.get_params(deep=deep).items(): + param_dict[f"wrapped_estimator__{key}"] = value return param_dict def set_params(self, **params: Any) -> Self: @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self: Parameters. """ param_copy = dict(params) - wrapped_estimator = param_copy.pop("wrapped_estimator") - if wrapped_estimator: - self.wrapped_estimator = wrapped_estimator - if param_copy: - if isinstance(self.wrapped_estimator, ABCPipelineElement): - self.wrapped_estimator.set_params(**param_copy) - else: - self.wrapped_estimator.set_params(**param_copy) + if "wrapped_estimator" in param_copy: + self.wrapped_estimator = param_copy.pop("wrapped_estimator") + wrapped_estimator_params = {} + for key, value in param_copy.items(): + estimator, _, param = key.partition("__") + if estimator == "wrapped_estimator": + wrapped_estimator_params[param] = value + self.wrapped_estimator.set_params(**wrapped_estimator_params) return self diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py new file mode 100644 index 00000000..cbfcf901 --- /dev/null +++ b/tests/test_elements/test_post_prediction.py @@ -0,0 +1,83 @@ +"""Test the module post_prediction.py.""" + +import unittest + +import numpy as np +from sklearn.base import clone +from sklearn.decomposition import PCA +from sklearn.ensemble import RandomForestClassifier + +from molpipeline.post_prediction import PostPredictionWrapper + + +class TestPostPredictionWrapper(unittest.TestCase): + """Test the PostPredictionWrapper class.""" + + def test_get_params(self) -> None: + """Test get_params method.""" + rf = RandomForestClassifier() + rf_params = rf.get_params(deep=True) + + ppw = PostPredictionWrapper(rf) + ppw_params = ppw.get_params(deep=True) + + wrapped_params = {} + for key, value in ppw_params.items(): + first, _, rest = key.partition("__") + if first == "wrapped_estimator": + if rest == "": + self.assertIs(rf, value) + else: + wrapped_params[rest] = value + + self.assertDictEqual(rf_params, wrapped_params) + + def test_set_params(self) -> None: + """Test set_params method.""" + rf = RandomForestClassifier() + ppw = PostPredictionWrapper(rf) + + ppw.set_params(wrapped_estimator__n_estimators=10) + self.assertIsInstance(ppw.wrapped_estimator, RandomForestClassifier) + if not isinstance(ppw.wrapped_estimator, RandomForestClassifier): + raise TypeError("Wrapped estimator is not a RandomForestClassifier.") + self.assertEqual(ppw.wrapped_estimator.n_estimators, 10) + + ppw_params = ppw.get_params(deep=True) + self.assertEqual(ppw_params["wrapped_estimator__n_estimators"], 10) + + def test_fit_transform(self) -> None: + """Test fit method.""" + rng = np.random.default_rng(20240918) + features = rng.random((10, 5)) + + pca = PCA(n_components=3) + pca.fit(features) + pca_transformed = pca.transform(features) + + ppw = PostPredictionWrapper(clone(pca)) + ppw.fit(features) + ppw_transformed = ppw.transform(features) + + self.assertEqual(pca_transformed.shape, ppw_transformed.shape) + self.assertTrue(np.allclose(pca_transformed, ppw_transformed)) + + def test_inverse_transform(self) -> None: + """Test inverse_transform method.""" + rng = np.random.default_rng(20240918) + features = rng.random((10, 5)) + + pca = PCA(n_components=3) + pca.fit(features) + pca_transformed = pca.transform(features) + pca_inverse = pca.inverse_transform(pca_transformed) + + ppw = PostPredictionWrapper(clone(pca)) + ppw.fit(features) + ppw_transformed = ppw.transform(features) + ppw_inverse = ppw.inverse_transform(ppw_transformed) + + self.assertEqual(features.shape, ppw_inverse.shape) + self.assertEqual(pca_inverse.shape, ppw_inverse.shape) + + self.assertTrue(np.allclose(pca_inverse, ppw_inverse))