Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#81: PostPredictionWrapper — handle set_params without a "wrapped_estimator" entry #82

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions molpipeline/post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Self

from numpy import typing as npt
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.base import BaseEstimator, TransformerMixin

from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement
from molpipeline.error_handling import FilterReinserter
Expand Down Expand Up @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
dict[str, Any]
Parameters.
"""
param_dict = {"wrapped_estimator": self.wrapped_estimator}
if deep:
param_dict = {
"wrapped_estimator": clone(self.wrapped_estimator),
}
else:
param_dict = {
"wrapped_estimator": self.wrapped_estimator,
}
param_dict.update(self.wrapped_estimator.get_params(deep=deep))
for key, value in self.wrapped_estimator.get_params(deep=deep).items():
param_dict[f"wrapped_estimator__{key}"] = value
return param_dict

def set_params(self, **params: Any) -> Self:
Expand All @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self:
Parameters.
"""
param_copy = dict(params)
wrapped_estimator = param_copy.pop("wrapped_estimator")
if wrapped_estimator:
self.wrapped_estimator = wrapped_estimator
if param_copy:
if isinstance(self.wrapped_estimator, ABCPipelineElement):
self.wrapped_estimator.set_params(**param_copy)
else:
self.wrapped_estimator.set_params(**param_copy)
if "wrapped_estimator" in param_copy:
self.wrapped_estimator = param_copy.pop("wrapped_estimator")
wrapped_estimator_params = {}
for key, value in param_copy.items():
estimator, _, param = key.partition("__")
if estimator == "wrapped_estimator":
wrapped_estimator_params[param] = value
self.wrapped_estimator.set_params(**wrapped_estimator_params)
return self
82 changes: 82 additions & 0 deletions tests/test_elements/test_post_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test the module post_prediction.py."""

import unittest

import numpy as np
from sklearn.base import clone
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

from molpipeline.post_prediction import PostPredictionWrapper


class TestPostPredictionWrapper(unittest.TestCase):
    """Unit tests for the PostPredictionWrapper parameter and transform API."""

    def test_get_params(self) -> None:
        """Test get_params method."""
        estimator = RandomForestClassifier()
        expected_params = estimator.get_params(deep=True)

        wrapper = PostPredictionWrapper(estimator)
        wrapper_params = wrapper.get_params(deep=True)

        # Collect every parameter namespaced under "wrapped_estimator__" and
        # check that the bare "wrapped_estimator" entry is the estimator itself.
        collected = {}
        for name, value in wrapper_params.items():
            prefix, _, suffix = name.partition("__")
            if prefix != "wrapped_estimator":
                continue
            if suffix:
                collected[suffix] = value
            else:
                self.assertEqual(estimator, value)

        # The namespaced parameters must mirror the estimator's own params.
        self.assertDictEqual(expected_params, collected)

    def test_set_params(self) -> None:
        """Test set_params method."""
        wrapper = PostPredictionWrapper(RandomForestClassifier())

        # Setting a prefixed parameter must reach the wrapped estimator.
        wrapper.set_params(wrapped_estimator__n_estimators=10)
        if not isinstance(wrapper.wrapped_estimator, RandomForestClassifier):
            raise TypeError("Wrapped estimator is not a RandomForestClassifier.")
        self.assertEqual(wrapper.wrapped_estimator.n_estimators, 10)

        # The change must also be visible through get_params.
        params = wrapper.get_params(deep=True)
        self.assertEqual(params["wrapped_estimator__n_estimators"], 10)

    def test_fit_transform(self) -> None:
        """Test fit method."""
        rng = np.random.default_rng(20240918)
        data = rng.random((100, 10))

        # Fit a reference PCA directly for comparison.
        reference = PCA(n_components=3)
        reference.fit(data)
        expected = reference.transform(data)

        # The wrapper around a clone must produce the same transform.
        wrapper = PostPredictionWrapper(clone(reference))
        wrapper.fit(data)
        actual = wrapper.transform(data)

        self.assertEqual(expected.shape, actual.shape)
        self.assertTrue(np.allclose(expected, actual))

    def test_inverse_transform(self) -> None:
        """Test inverse_transform method."""
        rng = np.random.default_rng(20240918)
        data = rng.random((5, 10))

        # Round-trip through a reference PCA fit directly on the data.
        reference = PCA(n_components=3)
        reference.fit(data)
        reference_transformed = reference.transform(data)
        expected_inverse = reference.inverse_transform(reference_transformed)

        # The wrapper's round-trip must match the reference round-trip.
        wrapper = PostPredictionWrapper(clone(reference))
        wrapper.fit(data)
        wrapper_transformed = wrapper.transform(data)
        recovered = wrapper.inverse_transform(wrapper_transformed)

        self.assertEqual(data.shape, recovered.shape)
        self.assertEqual(expected_inverse.shape, recovered.shape)

        self.assertTrue(np.allclose(expected_inverse, recovered))
Loading