Skip to content

Commit

Permalink
chore: fix imports
Browse files Browse the repository at this point in the history
Signed-off-by: David Wallace <[email protected]>
  • Loading branch information
MyPyDavid committed Nov 26, 2023
1 parent 08114f2 commit 57ac29a
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 245 deletions.
58 changes: 31 additions & 27 deletions src/raman_fitting/deconvolution_models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
""" The members of the validated collection of BasePeaks are assembled here into fitting Models"""
import logging

from typing import Optional, Dict
from warnings import warn
from typing import List, Optional, Dict

from lmfit.models import Model, GaussianModel

from raman_fitting.deconvolution_models.peak_validation import PeakModelValidator
from .base_peak import BasePeak, get_default_peaks

from lmfit.models import Model
from pydantic import (
BaseModel,
PositiveInt,
Field,
ConfigDict,
model_validator,
ValidationInfo,
field_validator,
ValidationError,
)

from .base_peak import BasePeak, get_default_peaks
from .lmfit import construct_lmfit_model_from_components
from ..config.filepath_helper import load_default_peak_toml_files

logger = logging.getLogger(__name__)

SUBSTRATE_PEAK = "Si1_peak"
Expand All @@ -31,19 +26,7 @@ class BaseModelWarning(UserWarning):
pass


def construct_lmfit_model_from_base_peaks(base_peaks: List[BasePeak]) -> Model:
"""
Construct the lmfit model from a collection of (known) peaks
"""
if not base_peaks:
raise ValueError("No peaks given to construct lmfit model from.")
lmfit_composite_model = sum(
map(lambda x: x.lmfit_model, base_peaks), base_peaks.pop().lmfit_model
)
return lmfit_composite_model


class BaseModelCollection(BaseModel):
class BaseModel(BaseModel):
"""
This Model class combines the collection of valid peaks from BasePeak into a regression model
of type lmfit.model.CompositeModel
Expand Down Expand Up @@ -116,10 +99,31 @@ def check_lmfit_model(self) -> "BaseModelCollection":
self.lmfit_model = lmfit_model
return self

def construct_lmfit_model(self, peaks, default_peaks) -> "Model":
def construct_lmfit_model(self, peaks, default_peaks) -> Optional["Model"]:
peak_names = peaks.split(SEP)
base_peaks = [default_peaks[i] for i in peak_names if i in default_peaks]
if not base_peaks:
return GaussianModel()
lmfit_model = construct_lmfit_model_from_base_peaks(base_peaks)
return None
base_peaks_lmfit = [i.lmfit_model for i in base_peaks]
lmfit_model = construct_lmfit_model_from_components(base_peaks_lmfit)
return lmfit_model


def get_default_models() -> Dict[str, BasePeak]:
settings = load_default_peak_toml_files()
default_peaks = get_default_peaks()
models_settings = {k: val.get("models") for k, val in settings.items()}
base_models = {}
for model_name, model_peaks in models_settings.items():
base_models[model_name] = BaseModel(
name=model_name, peaks=model_peaks, default_peaks=default_peaks
)
return base_models


def main():
models = get_default_models()


if __name__ == "__main__":
main()
35 changes: 8 additions & 27 deletions src/raman_fitting/deconvolution_models/base_peak.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,21 @@
import inspect
from functools import partialmethod

from warnings import warn
from enum import StrEnum

from lmfit import Parameter, Parameters

from lmfit.models import GaussianModel, LorentzianModel, Model, VoigtModel
from .lmfit_models import LMFIT_MODEL_MAPPER, LMFitParameterHints, parmeter_to_dict
from ..config.filepath_helper import load_default_peak_toml_files

from typing import List, Literal, Optional, Dict, final
from typing import List, Optional, Dict

from pydantic import (
BaseModel,
ConfigDict,
InstanceOf,
Field,
ValidationError,
ValidationInfo,
field_validator,
model_validator,
)
from pytest import param
from lmfit import Parameters
from lmfit.models import Model

from .lmfit import LMFIT_MODEL_MAPPER, LMFitParameterHints, parmeter_to_dict
from ..config.filepath_helper import load_default_peak_toml_files

param_hint_dict = Dict[str, Dict[str, Optional[float | bool | str]]]
ParamHintDict = Dict[str, Dict[str, Optional[float | bool | str]]]


class BasePeakWarning(UserWarning): # pragma: no cover
Expand All @@ -35,13 +25,6 @@ class BasePeakWarning(UserWarning): # pragma: no cover
PEAK_TYPE_OPTIONS = StrEnum("PEAK_TYPE_OPTIONS", ["Lorentzian", "Gaussian", "Voigt"])


LMFIT_MODEL_MAPPER = {
"Lorentzian": LorentzianModel,
"Gaussian": GaussianModel,
"Voigt": VoigtModel,
}


def get_lmfit_model_from_peak_type(peak_type: str, prefix: str = "") -> Optional[Model]:
"""returns the lmfit model instance according to the chosen peak type and sets the prefix from peak_name"""
model = None
Expand Down Expand Up @@ -117,9 +100,7 @@ class New_peak(metaclass=BasePeak):
model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)

peak_name: str
param_hints: Optional[
Parameters | List[LMFitParameterHints] | param_hint_dict
] = None
param_hints: Optional[Parameters | List[LMFitParameterHints] | ParamHintDict] = None
peak_type: Optional[str] = None
is_substrate: Optional[bool] = False
is_for_normalization: Optional[bool] = False
Expand Down Expand Up @@ -148,7 +129,7 @@ def check_peak_type(cls, v: Optional[str]) -> Optional[str]:
@field_validator("param_hints")
@classmethod
def check_param_hints(
cls, v: Optional[Parameters | List[LMFitParameterHints] | param_hint_dict]
cls, v: Optional[Parameters | List[LMFitParameterHints] | ParamHintDict]
) -> Optional[Parameters]:
if v is None:
return v
Expand Down
8 changes: 4 additions & 4 deletions src/raman_fitting/deconvolution_models/fit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pandas as pd

from ..processing.spectrum_constructor import SpectrumDataCollection, SpectrumDataLoader
from .init_models import InitializeModels
# from ..processing.spectrum_constructor import SpectrumDataCollection, SpectrumDataLoader
# from .init_models import InitializeModels

logger = logging.getLogger(__name__)

Expand All @@ -19,11 +19,11 @@ class Fitter:

fit_windows = ["1st_order", "2nd_order"]

def __init__(self, spectra_arg, RamanModels=InitializeModels(), start_fit=True):
def __init__(self, spectra_arg, models=None, start_fit=True):
self._qcnm = self.__class__.__qualname__
logger.debug(f"{self._qcnm} is called with spectrum\n\t{spectra_arg}\n")
self.start_fit = start_fit
self.models = RamanModels
self.models = models

self.spectra_arg = spectra_arg
self.spectra = spectra_arg
Expand Down
5 changes: 0 additions & 5 deletions src/raman_fitting/deconvolution_models/init_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from dataclasses import dataclass
import logging
from warnings import warn

from lmfit import Model

from raman_fitting.config.filepath_helper import load_default_peak_toml_files


from raman_fitting.deconvolution_models.base_model import SEP
from raman_fitting.deconvolution_models.base_peak import BasePeak

Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,17 @@
from ast import main
import inspect
from functools import partialmethod
import math

from os import name
from pyexpat import model

from unittest.mock import DEFAULT
from warnings import warn
from enum import StrEnum
from typing import List, Optional, Dict

from lmfit import Parameter, Parameters

from lmfit import Parameter
from lmfit.models import GaussianModel, LorentzianModel, Model, VoigtModel

from typing import List, Literal, Optional, Dict, final
import numpy

from pydantic import (
BaseModel,
ConfigDict,
InstanceOf,
Field,
ValidationError,
ValidationInfo,
field_validator,
model_validator,
)
from pytest import param


param_hint_dict = Dict[str, Dict[str, Optional[float | bool | str]]]
Expand All @@ -42,6 +26,13 @@ class BasePeakWarning(UserWarning): # pragma: no cover
LMFIT_PARAM_KWARGS = ("value", "vary", "min", "max", "expr")


LMFIT_MODEL_MAPPER = {
"Lorentzian": LorentzianModel,
"Gaussian": GaussianModel,
"Voigt": VoigtModel,
}


class LMFitParameterHints(BaseModel):
"""
https://github.com/lmfit/lmfit-py/blob/master/lmfit/model.py#L566
Expand Down Expand Up @@ -130,15 +121,14 @@ def check_construct_parameter(self) -> "LMFitParameterHints":
return self


DEFAULT_GAMMA_PARAM_HINT = LMFitParameterHints(
name="gamma", value=1, min=1e-05, max=70, vary=False
)

LMFIT_MODEL_MAPPER = {
"Lorentzian": LorentzianModel,
"Gaussian": GaussianModel,
"Voigt": VoigtModel,
}
def construct_lmfit_model_from_components(models: List[Model]) -> "Model":
"""
Construct the lmfit model from a collection of (known) peaks
"""
if not models:
raise ValueError("No peaks given to construct lmfit model from.")
lmfit_composite_model = sum(models, models.pop())
return lmfit_composite_model


def parmeter_to_dict(parameter: Parameter) -> dict:
Expand All @@ -147,6 +137,11 @@ def parmeter_to_dict(parameter: Parameter) -> dict:
return ret


DEFAULT_GAMMA_PARAM_HINT = LMFitParameterHints(
name="gamma", value=1, min=1e-05, max=70, vary=False
)


def main():
breakpoint()

Expand Down
23 changes: 0 additions & 23 deletions src/raman_fitting/deconvolution_models/model_config.py

This file was deleted.

Loading

0 comments on commit 57ac29a

Please sign in to comment.