diff --git a/src/raman_fitting/deconvolution_models/base_model.py b/src/raman_fitting/deconvolution_models/base_model.py index a0d140c..5a4a5e3 100644 --- a/src/raman_fitting/deconvolution_models/base_model.py +++ b/src/raman_fitting/deconvolution_models/base_model.py @@ -1,25 +1,20 @@ """ The members of the validated collection of BasePeaks are assembled here into fitting Models""" import logging - +from typing import Optional, Dict from warnings import warn -from typing import List, Optional, Dict - -from lmfit.models import Model, GaussianModel - -from raman_fitting.deconvolution_models.peak_validation import PeakModelValidator -from .base_peak import BasePeak, get_default_peaks +from lmfit.models import Model from pydantic import ( BaseModel, - PositiveInt, Field, ConfigDict, model_validator, - ValidationInfo, - field_validator, - ValidationError, ) +from .base_peak import BasePeak, get_default_peaks +from .lmfit import construct_lmfit_model_from_components +from ..config.filepath_helper import load_default_peak_toml_files + logger = logging.getLogger(__name__) SUBSTRATE_PEAK = "Si1_peak" @@ -31,19 +26,7 @@ class BaseModelWarning(UserWarning): pass -def construct_lmfit_model_from_base_peaks(base_peaks: List[BasePeak]) -> Model: - """ - Construct the lmfit model from a collection of (known) peaks - """ - if not base_peaks: - raise ValueError("No peaks given to construct lmfit model from.") - lmfit_composite_model = sum( - map(lambda x: x.lmfit_model, base_peaks), base_peaks.pop().lmfit_model - ) - return lmfit_composite_model - - -class BaseModelCollection(BaseModel): +class BaseModel(BaseModel): """ This Model class combines the collection of valid peaks from BasePeak into a regression model of type lmfit.model.CompositeModel @@ -116,10 +99,31 @@ def check_lmfit_model(self) -> "BaseModelCollection": self.lmfit_model = lmfit_model return self - def construct_lmfit_model(self, peaks, default_peaks) -> "Model": + def construct_lmfit_model(self, peaks, default_peaks) -> Optional["Model"]: peak_names = peaks.split(SEP) base_peaks = [default_peaks[i] for i in peak_names if i in default_peaks] if not base_peaks: - return GaussianModel() - lmfit_model = construct_lmfit_model_from_base_peaks(base_peaks) + return None + base_peaks_lmfit = [i.lmfit_model for i in base_peaks] + lmfit_model = construct_lmfit_model_from_components(base_peaks_lmfit) return lmfit_model + + +def get_default_models() -> Dict[str, BasePeak]: + settings = load_default_peak_toml_files() + default_peaks = get_default_peaks() + models_settings = {k: val.get("models") for k, val in settings.items()} + base_models = {} + for model_name, model_peaks in models_settings.items(): + base_models[model_name] = BaseModel( + name=model_name, peaks=model_peaks, default_peaks=default_peaks + ) + return base_models + + +def main(): + models = get_default_models() + + +if __name__ == "__main__": + main() diff --git a/src/raman_fitting/deconvolution_models/base_peak.py b/src/raman_fitting/deconvolution_models/base_peak.py index a2462d2..61b7bc0 100644 --- a/src/raman_fitting/deconvolution_models/base_peak.py +++ b/src/raman_fitting/deconvolution_models/base_peak.py @@ -1,31 +1,21 @@ -import inspect -from functools import partialmethod - -from warnings import warn from enum import StrEnum - -from lmfit import Parameter, Parameters - -from lmfit.models import GaussianModel, LorentzianModel, Model, VoigtModel -from .lmfit_models import LMFIT_MODEL_MAPPER, LMFitParameterHints, parmeter_to_dict -from ..config.filepath_helper import load_default_peak_toml_files - -from typing import List, Literal, Optional, Dict, final +from typing import List, Optional, Dict from pydantic import ( BaseModel, ConfigDict, InstanceOf, Field, - ValidationError, - ValidationInfo, field_validator, model_validator, ) -from pytest import param +from lmfit import Parameters +from lmfit.models import Model +from .lmfit import LMFIT_MODEL_MAPPER, LMFitParameterHints, parmeter_to_dict +from ..config.filepath_helper import load_default_peak_toml_files -param_hint_dict = Dict[str, Dict[str, Optional[float | bool | str]]] +ParamHintDict = Dict[str, Dict[str, Optional[float | bool | str]]] class BasePeakWarning(UserWarning): # pragma: no cover @@ -35,13 +25,6 @@ class BasePeakWarning(UserWarning): # pragma: no cover PEAK_TYPE_OPTIONS = StrEnum("PEAK_TYPE_OPTIONS", ["Lorentzian", "Gaussian", "Voigt"]) -LMFIT_MODEL_MAPPER = { - "Lorentzian": LorentzianModel, - "Gaussian": GaussianModel, - "Voigt": VoigtModel, -} - - def get_lmfit_model_from_peak_type(peak_type: str, prefix: str = "") -> Optional[Model]: """returns the lmfit model instance according to the chosen peak type and sets the prefix from peak_name""" model = None @@ -117,9 +100,7 @@ class New_peak(metaclass=BasePeak): model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) peak_name: str - param_hints: Optional[ - Parameters | List[LMFitParameterHints] | param_hint_dict - ] = None + param_hints: Optional[Parameters | List[LMFitParameterHints] | ParamHintDict] = None peak_type: Optional[str] = None is_substrate: Optional[bool] = False is_for_normalization: Optional[bool] = False @@ -148,7 +129,7 @@ def check_peak_type(cls, v: Optional[str]) -> Optional[str]: @field_validator("param_hints") @classmethod def check_param_hints( - cls, v: Optional[Parameters | List[LMFitParameterHints] | param_hint_dict] + cls, v: Optional[Parameters | List[LMFitParameterHints] | ParamHintDict] ) -> Optional[Parameters]: if v is None: return v diff --git a/src/raman_fitting/deconvolution_models/fit_models.py b/src/raman_fitting/deconvolution_models/fit_models.py index 1a6f455..c74bcbe 100644 --- a/src/raman_fitting/deconvolution_models/fit_models.py +++ b/src/raman_fitting/deconvolution_models/fit_models.py @@ -4,8 +4,8 @@ import pandas as pd -from ..processing.spectrum_constructor import SpectrumDataCollection, SpectrumDataLoader -from .init_models import InitializeModels +# from ..processing.spectrum_constructor import SpectrumDataCollection, SpectrumDataLoader +# from .init_models import InitializeModels logger = logging.getLogger(__name__) @@ -19,11 +19,11 @@ class Fitter: fit_windows = ["1st_order", "2nd_order"] - def __init__(self, spectra_arg, RamanModels=InitializeModels(), start_fit=True): + def __init__(self, spectra_arg, models=None, start_fit=True): self._qcnm = self.__class__.__qualname__ logger.debug(f"{self._qcnm} is called with spectrum\n\t{spectra_arg}\n") self.start_fit = start_fit - self.models = RamanModels + self.models = models self.spectra_arg = spectra_arg self.spectra = spectra_arg diff --git a/src/raman_fitting/deconvolution_models/init_models.py b/src/raman_fitting/deconvolution_models/init_models.py index e0426a3..a18b1c4 100644 --- a/src/raman_fitting/deconvolution_models/init_models.py +++ b/src/raman_fitting/deconvolution_models/init_models.py @@ -1,12 +1,7 @@ from dataclasses import dataclass import logging -from warnings import warn - -from lmfit import Model from raman_fitting.config.filepath_helper import load_default_peak_toml_files - - from raman_fitting.deconvolution_models.base_model import SEP from raman_fitting.deconvolution_models.base_peak import BasePeak diff --git a/src/raman_fitting/deconvolution_models/lmfit_models.py b/src/raman_fitting/deconvolution_models/lmfit.py similarity index 91% rename from src/raman_fitting/deconvolution_models/lmfit_models.py rename to src/raman_fitting/deconvolution_models/lmfit.py index 6f93c6b..c1bce3b 100644 --- a/src/raman_fitting/deconvolution_models/lmfit_models.py +++ b/src/raman_fitting/deconvolution_models/lmfit.py @@ -1,33 +1,17 @@ from ast import main -import inspect -from functools import partialmethod import math - -from os import name -from pyexpat import model - -from unittest.mock import DEFAULT -from warnings import warn from enum import StrEnum +from typing import List, Optional, Dict -from lmfit import Parameter, Parameters - +from lmfit import Parameter from lmfit.models import GaussianModel, LorentzianModel, Model, VoigtModel -from typing import List, Literal, Optional, Dict, final -import numpy - from pydantic import ( BaseModel, ConfigDict, - InstanceOf, Field, - ValidationError, - ValidationInfo, - field_validator, model_validator, ) -from pytest import param param_hint_dict = Dict[str, Dict[str, Optional[float | bool | str]]] @@ -42,6 +26,13 @@ class BasePeakWarning(UserWarning): # pragma: no cover LMFIT_PARAM_KWARGS = ("value", "vary", "min", "max", "expr") +LMFIT_MODEL_MAPPER = { + "Lorentzian": LorentzianModel, + "Gaussian": GaussianModel, + "Voigt": VoigtModel, +} + + class LMFitParameterHints(BaseModel): """ https://github.com/lmfit/lmfit-py/blob/master/lmfit/model.py#L566 @@ -130,15 +121,14 @@ def check_construct_parameter(self) -> "LMFitParameterHints": return self -DEFAULT_GAMMA_PARAM_HINT = LMFitParameterHints( - name="gamma", value=1, min=1e-05, max=70, vary=False -) - -LMFIT_MODEL_MAPPER = { - "Lorentzian": LorentzianModel, - "Gaussian": GaussianModel, - "Voigt": VoigtModel, -} +def construct_lmfit_model_from_components(models: List[Model]) -> "Model": + """ + Construct the lmfit model from a collection of (known) peaks + """ + if not models: + raise ValueError("No peaks given to construct lmfit model from.") + lmfit_composite_model = sum(models, models.pop()) + return lmfit_composite_model def parmeter_to_dict(parameter: Parameter) -> dict: @@ -147,6 +137,11 @@ def parmeter_to_dict(parameter: Parameter) -> dict: return ret +DEFAULT_GAMMA_PARAM_HINT = LMFitParameterHints( + name="gamma", value=1, min=1e-05, max=70, vary=False +) + + def main(): breakpoint() diff --git a/src/raman_fitting/deconvolution_models/model_config.py b/src/raman_fitting/deconvolution_models/model_config.py deleted file mode 100644 index b0aaa9e..0000000 --- a/src/raman_fitting/deconvolution_models/model_config.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Created on Sun May 30 12:35:58 2021 - -@author: DW -""" - -from raman_fitting.model_validation import PeakModelValidator - - -class ModelConfigurator: - standard_config_file = "model_config_standard.cfg" - - def __init__(self, **kwargs): - self._kwargs = kwargs - - def find_user_config_files(self): - pass - - def file_handler(self): - pass - - def standard_valid_models(self): - peak_collection = PeakModelValidator() diff --git a/src/raman_fitting/deconvolution_models/peak_validation.py b/src/raman_fitting/deconvolution_models/peak_validation.py index 2b0fd99..b524846 100644 --- a/src/raman_fitting/deconvolution_models/peak_validation.py +++ b/src/raman_fitting/deconvolution_models/peak_validation.py @@ -6,22 +6,18 @@ @author: zmg """ -import inspect import logging from collections import namedtuple -from itertools import groupby -from pathlib import Path from typing import Tuple from warnings import warn import matplotlib.pyplot as plt -import pandas as pd -from lmfit import Parameters -from .. import __package_name__ -from .base_peak import BasePeak +logger = logging.getLogger(__file__) -logger = logging.getLogger(__package_name__) + +CMAP_OPTIONS_DEFAULT = ("Dark2", "tab20") +fallback_color = (0.4, 0.4, 0.4, 1.0) class PeakValidationWarning(UserWarning): @@ -48,16 +44,10 @@ class PeakModelValidator: """ - # _standard_modules = [first_order_peaks, second_order_peaks, normalization_peaks] - BASE_PEAK = BasePeak - ModelValidation = namedtuple( "ModelValidation", "valid peak_group model_inst message" ) - CMAP_OPTIONS_DEFAULT = ("Dark2", "tab20") - fallback_color = (0.4, 0.4, 0.4, 1.0) - debug = False def __init__(self, *args, cmap_options=CMAP_OPTIONS_DEFAULT, **kwargs): @@ -77,38 +67,6 @@ def __init__(self, *args, cmap_options=CMAP_OPTIONS_DEFAULT, **kwargs): self.model_dict = self.get_model_dict(self.lmfit_models) self.options = self.model_dict.keys() - def get_subclasses_from_base(self, _BaseClass): - """Finds subclasses of the BasePeak metaclass, these should give already valid models""" - - _all_subclasses = [] - if inspect.isclass(_BaseClass): - if hasattr(_BaseClass, "subclasses"): - _all_subclasses = _BaseClass.subclasses - elif hasattr(_BaseClass, "__subclassess__"): - _all_subclasses = _BaseClass.__subclasses__ - else: - warn( - f"\nNo baseclasses were found for {str(_BaseClass)}:\n missing attributes", - NotFoundAnyModelsWarning, - ) - else: - warn( - f"\nNo baseclasses were found for {str(_BaseClass)}:\n is not a class", - NotFoundAnyModelsWarning, - ) - - if not _all_subclasses: - warn( - f"\nNo baseclasses were found in inspected modules for {str(_BaseClass)}:\n", - NotFoundAnyModelsWarning, - ) - - return _all_subclasses - - def _inspect_modules_for_classes(self): - """Optional method Inspect other modules for subclasses""" - pass - def validation_inspect_models(self, inspect_models: list = []): """Validates each member of a list for making a valid model instance""" _model_validations = [] @@ -159,35 +117,6 @@ def sort_selected_models(self, value): _sorted = sorted(_sorted, key=lambda x: x.peak_group) return _sorted - def validate_model_instance(self, value): - """ - Returns a boolean, model and message depending on the validation of the model class. - Invalid classes can raise warnings, but exception only when no valid models are found. - """ - - try: - _inst = value() - except Exception as e: - _err = f"Unable to initialize model {value},\n{e}" - warn(f"{_err}", CanNotInitializeModelWarning) - return (False, value, _err) - - for field in self.BASE_PEAK._fields: - if not hasattr(_inst, field): - return (False, value, f"instance {_inst} has no attr {field}.\n") - if not getattr(_inst, field): - return (False, value, f"instance {_inst}, {field} is None.\n") - if "param_hints" in field: - _settings = getattr(_inst, field) - _center = _settings.get("center", None) - if not _center: - return ( - False, - value, - f"instance {_inst}, settings {_settings} center is None.\n", - ) - return (True, _inst, f"{_inst} is a valid model") - @staticmethod def get_cmap_list( lst, cmap_options: Tuple = (), fallback_color: Tuple = () @@ -233,40 +162,6 @@ def assign_colors_to_lmfit_mod_inst(self, selected_models: list): lmfit_models.append(_m_inst) return lmfit_models - def add_standard_init_params(self): - self.standard_init_params = Parameters() - self.standard_init_params.add_many(*BasePeak._params_guesses_base) - - def add_model_names_var_names(self, lmfit_models): - _mod_param_names = { - i.lmfit_model.name: i.lmfit_model.param_names for i in lmfit_models - } - return _mod_param_names - - def get_model_dict(self, lmfit_models): - model_dict = {i.__class__.__name__: i for i in lmfit_models} - return model_dict - - def get_dict(self): - return { - i.__module__ + ", " + i.__class__.__name__: i for i in self.lmfit_models - } - - def __getattr__(self, name): - try: - _options = self.__getattribute__("options") - if name in _options: - return self.model_dict.get(name, None) - raise AttributeError( - f'Chosen name "{name}" not in options: "{", ".join(_options)}".' - ) - except AttributeError: - raise AttributeError(f'Chosen name "{name}" not in attributes') - - def __iter__(self): - for mod_inst in self.lmfit_models: - yield mod_inst - def __repr__(self): _repr = "Validated Peak model collection" if self.selected_models: diff --git a/tests/deconvolution_models/test_base_model.py b/tests/deconvolution_models/test_base_model.py index 2683e6c..df4f6e2 100644 --- a/tests/deconvolution_models/test_base_model.py +++ b/tests/deconvolution_models/test_base_model.py @@ -6,16 +6,13 @@ import unittest from functools import partial -from operator import itemgetter -import pytest -from lmfit import Model +from pydantic import ValidationError from raman_fitting.deconvolution_models.base_model import ( SUBSTRATE_PEAK, - BaseModelCollection, + BaseModel, ) -from pydantic import ValidationError SUBSTRATE_PREFIX = SUBSTRATE_PEAK.split("peak")[0] @@ -28,13 +25,13 @@ def helper_get_list_components(bm): class TestBaseModel(unittest.TestCase): def test_empty_base_model(self): - self.assertRaises(ValidationError, BaseModelCollection) - self.assertRaises(ValidationError, BaseModelCollection, name="Test_empty") - self.assertRaises(ValidationError, BaseModelCollection, peaks="A+B") - bm = BaseModelCollection(name="Test_empty", peaks="A+B") + self.assertRaises(ValidationError, BaseModel) + self.assertRaises(ValidationError, BaseModel, name="Test_empty") + self.assertRaises(ValidationError, BaseModel, peaks="A+B") + bm = BaseModel(name="Test_empty", peaks="A+B") def test_base_model_2peaks(self): - bm = BaseModelCollection(name="Test_2peaks", peaks="K2+D+G") + bm = BaseModel(name="Test_2peaks", peaks="K2+D+G") print(bm) self.assertSetEqual(set(helper_get_list_components(bm)), set(["D_", "G_"])) @@ -46,9 +43,7 @@ def test_base_model_2peaks(self): self.assertSetEqual(set(helper_get_list_components(bm)), set(["D_", "G_"])) def test_base_model_wrong_chars_model_name(self): - bm = BaseModelCollection( - name="Test_wrong_chars", peaks="K2+---////+ +7 +K1111+1D+D2" - ) + bm = BaseModel(name="Test_wrong_chars", peaks="K2+---////+ +7 +K1111+1D+D2") self.assertSetEqual(set(helper_get_list_components(bm)), set(["D2_"])) bm.add_substrate() self.assertSetEqual( diff --git a/tests/deconvolution_models/test_fit_models.py b/tests/deconvolution_models/test_fit_models.py index 2a1a9fa..eb073c9 100644 --- a/tests/deconvolution_models/test_fit_models.py +++ b/tests/deconvolution_models/test_fit_models.py @@ -10,19 +10,9 @@ import unittest import pandas as pd -import pytest -from lmfit import Model -import raman_fitting from raman_fitting.deconvolution_models.fit_models import Fitter, PrepareParams -# try: -# import raman_fitting - - -# except Exception as e: -# print(f'pytest file {__file__}, {__name__} error {e}') - class TestFitter(unittest.TestCase): def test_empty_Fitter(self):