Skip to content

Commit

Permalink
fix: Renamed MLModelType to ModelType
Browse files Browse the repository at this point in the history
Signed-off-by: Clara De Smet <[email protected]>
  • Loading branch information
clara-de-smet committed Oct 2, 2024
1 parent e2fafd2 commit 53082af
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 68 deletions.
3 changes: 1 addition & 2 deletions openstef/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from enum import Enum


# TODO replace this with ModelType (MLModelType == Machine Learning model type)
class MLModelType(Enum):
class ModelType(Enum):
XGB = "xgb"
XGB_QUANTILE = "xgb_quantile"
XGB_MULTIOUTPUT_QUANTILE = "xgb_multioutput_quantile"
Expand Down
40 changes: 20 additions & 20 deletions openstef/model/model_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import structlog

from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.model.regressors.arima import ARIMAOpenstfRegressor
from openstef.model.regressors.custom_regressor import is_custom_type, load_custom_model
from openstef.model.regressors.lgbm import LGBMOpenstfRegressor
Expand All @@ -29,7 +29,7 @@
logger = structlog.get_logger(__name__)

valid_model_kwargs = {
MLModelType.XGB: [
ModelType.XGB: [
"n_estimators",
"objective",
"max_depth",
Expand Down Expand Up @@ -60,7 +60,7 @@
"validate_parameters",
"early_stopping_rounds",
],
MLModelType.LGB: [
ModelType.LGB: [
"boosting_type",
"objective",
"num_leaves",
Expand All @@ -82,7 +82,7 @@
"importance_type",
"early_stopping_rounds",
],
MLModelType.XGB_QUANTILE: [
ModelType.XGB_QUANTILE: [
"quantiles",
"gamma",
"colsample_bytree",
Expand All @@ -91,7 +91,7 @@
"max_depth",
"early_stopping_rounds",
],
MLModelType.XGB_MULTIOUTPUT_QUANTILE: [
ModelType.XGB_MULTIOUTPUT_QUANTILE: [
"quantiles",
"gamma",
"colsample_bytree",
Expand All @@ -101,23 +101,23 @@
"early_stopping_rounds",
"arctan_smoothing",
],
MLModelType.LINEAR: [
ModelType.LINEAR: [
"missing_values",
"imputation_strategy",
"fill_value",
],
MLModelType.FLATLINER: [
ModelType.FLATLINER: [
"quantiles",
],
MLModelType.LINEAR_QUANTILE: [
ModelType.LINEAR_QUANTILE: [
"alpha",
"quantiles",
"solver",
"missing_values",
"imputation_strategy",
"fill_value",
],
MLModelType.ARIMA: [
ModelType.ARIMA: [
"backtest_max_horizon",
"order",
"seasonal_order",
Expand All @@ -131,18 +131,18 @@ class ModelCreator:

# Set object mapping
MODEL_CONSTRUCTORS = {
MLModelType.XGB: XGBOpenstfRegressor,
MLModelType.LGB: LGBMOpenstfRegressor,
MLModelType.XGB_QUANTILE: XGBQuantileOpenstfRegressor,
MLModelType.XGB_MULTIOUTPUT_QUANTILE: XGBMultiOutputQuantileOpenstfRegressor,
MLModelType.LINEAR: LinearOpenstfRegressor,
MLModelType.LINEAR_QUANTILE: LinearQuantileOpenstfRegressor,
MLModelType.ARIMA: ARIMAOpenstfRegressor,
MLModelType.FLATLINER: FlatlinerRegressor,
ModelType.XGB: XGBOpenstfRegressor,
ModelType.LGB: LGBMOpenstfRegressor,
ModelType.XGB_QUANTILE: XGBQuantileOpenstfRegressor,
ModelType.XGB_MULTIOUTPUT_QUANTILE: XGBMultiOutputQuantileOpenstfRegressor,
ModelType.LINEAR: LinearOpenstfRegressor,
ModelType.LINEAR_QUANTILE: LinearQuantileOpenstfRegressor,
ModelType.ARIMA: ARIMAOpenstfRegressor,
ModelType.FLATLINER: FlatlinerRegressor,
}

@staticmethod
def create_model(model_type: Union[MLModelType, str], **kwargs) -> OpenstfRegressor:
def create_model(model_type: Union[ModelType, str], **kwargs) -> OpenstfRegressor:
"""Create a machine learning model based on model type.
Args:
Expand All @@ -163,7 +163,7 @@ def create_model(model_type: Union[MLModelType, str], **kwargs) -> OpenstfRegres
model_class = load_custom_model(model_type)
valid_kwargs = model_class.valid_kwargs()
else:
model_type = MLModelType(model_type)
model_type = ModelType(model_type)
model_class = ModelCreator.MODEL_CONSTRUCTORS[model_type]
valid_kwargs = valid_model_kwargs[model_type]
# Check if model as imported
Expand All @@ -174,7 +174,7 @@ def create_model(model_type: Union[MLModelType, str], **kwargs) -> OpenstfRegres
"Please refer to the ReadMe for instructions"
)
except ValueError as e:
valid_types = [t.value for t in MLModelType]
valid_types = [t.value for t in ModelType]
raise NotImplementedError(
f"No constructor for '{model_type}', "
f"valid model_types are: {valid_types} "
Expand Down
14 changes: 7 additions & 7 deletions openstef/model/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import optuna
import pandas as pd

from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.metrics import metrics
from openstef.metrics.reporter import Report, Reporter
from openstef.model.regressors.regressor import OpenstfRegressor
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_default_values(cls) -> dict:
class XGBRegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.XGB
self.model_type = ModelType.XGB

# extend the parameters with the model specific ones per implementation
def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
Expand Down Expand Up @@ -282,7 +282,7 @@ def get_default_values(cls) -> dict:
class LGBRegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.LGB
self.model_type = ModelType.LGB

def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
"""Get parameters for LGB Regressor Objective with objective specific parameters.
Expand Down Expand Up @@ -323,7 +323,7 @@ def get_pruning_callback(self, trial: optuna.trial.FrozenTrial):
class XGBQuantileRegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.XGB_QUANTILE
self.model_type = ModelType.XGB_QUANTILE

def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
"""Get parameters for XGBQuantile Regressor Objective with objective specific parameters.
Expand Down Expand Up @@ -352,7 +352,7 @@ def get_pruning_callback(self, trial: optuna.trial.FrozenTrial):
class XGBMultioutputQuantileRegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.XGB_QUANTILE
self.model_type = ModelType.XGB_QUANTILE

def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
"""Get parameters for XGB Multioutput Quantile Regressor Objective with objective specific parameters.
Expand Down Expand Up @@ -382,7 +382,7 @@ def get_pruning_callback(self, trial: optuna.trial.FrozenTrial):
class LinearRegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.LINEAR
self.model_type = ModelType.LINEAR

def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
"""Get parameters for Linear Regressor Objective with objective specific parameters.
Expand All @@ -405,7 +405,7 @@ def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
class ARIMARegressorObjective(RegressorObjective):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_type = MLModelType.ARIMA
self.model_type = ModelType.ARIMA

def get_params(self, trial: optuna.trial.FrozenTrial) -> dict:
"""Get parameters for ARIMA Regressor Objective with objective specific parameters.
Expand Down
22 changes: 11 additions & 11 deletions openstef/model/objective_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Union

from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.model.objective import (
ARIMARegressorObjective,
LGBRegressorObjective,
Expand All @@ -22,17 +22,17 @@

class ObjectiveCreator:
OBJECTIVES = {
MLModelType.XGB: XGBRegressorObjective,
MLModelType.LGB: LGBRegressorObjective,
MLModelType.XGB_QUANTILE: XGBQuantileRegressorObjective,
MLModelType.XGB_MULTIOUTPUT_QUANTILE: XGBMultioutputQuantileRegressorObjective,
MLModelType.LINEAR: LinearRegressorObjective,
MLModelType.LINEAR_QUANTILE: LinearRegressorObjective,
MLModelType.ARIMA: ARIMARegressorObjective,
ModelType.XGB: XGBRegressorObjective,
ModelType.LGB: LGBRegressorObjective,
ModelType.XGB_QUANTILE: XGBQuantileRegressorObjective,
ModelType.XGB_MULTIOUTPUT_QUANTILE: XGBMultioutputQuantileRegressorObjective,
ModelType.LINEAR: LinearRegressorObjective,
ModelType.LINEAR_QUANTILE: LinearRegressorObjective,
ModelType.ARIMA: ARIMARegressorObjective,
}

@staticmethod
def create_objective(model_type: Union[MLModelType, str]) -> RegressorObjective:
def create_objective(model_type: Union[ModelType, str]) -> RegressorObjective:
"""Create an objective function based on model type.
Args:
Expand All @@ -51,10 +51,10 @@ def create_objective(model_type: Union[MLModelType, str]) -> RegressorObjective:
if is_custom_type(model_type):
objective = create_custom_objective(model_type)
else:
model_type = MLModelType(model_type)
model_type = ModelType(model_type)
objective = ObjectiveCreator.OBJECTIVES[model_type]
except ValueError as e:
valid_types = [t.value for t in MLModelType]
valid_types = [t.value for t in ModelType]
raise NotImplementedError(
f"No objective for '{model_type}', "
f"valid model_types are: {valid_types}"
Expand Down
6 changes: 3 additions & 3 deletions openstef/tasks/calculate_kpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import structlog

from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.exceptions import NoPredictedLoadError, NoRealisedLoadError
from openstef.metrics import metrics
from openstef.settings import Settings
Expand All @@ -42,7 +42,7 @@
THRESHOLD_OPTIMIZING = 0.50


def main(model_type: MLModelType = None, config=None, database=None) -> None:
def main(model_type: ModelType = None, config=None, database=None) -> None:
taskname = Path(__file__).name.replace(".py", "")

if database is None or config is None:
Expand All @@ -52,7 +52,7 @@ def main(model_type: MLModelType = None, config=None, database=None) -> None:
)

if model_type is None:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

with TaskContext(taskname, config, database) as context:
# Set start and end time
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/create_components_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import structlog

from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.exceptions import ComponentForecastTooShortHorizonError
from openstef.pipeline.create_component_forecast import (
create_components_forecast_pipeline,
Expand Down Expand Up @@ -150,7 +150,7 @@ def main(config: object = None, database: object = None, **kwargs):
)

with TaskContext(taskname, config, database) as context:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

PredictionJobLoop(
context,
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/create_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pathlib import Path

from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType, PipelineType
from openstef.enums import ModelType, PipelineType
from openstef.exceptions import InputDataOngoingZeroFlatlinerError
from openstef.pipeline.create_forecast import create_forecast_pipeline
from openstef.tasks.utils.predictionjobloop import PredictionJobLoop
Expand Down Expand Up @@ -129,7 +129,7 @@ def main(model_type=None, config=None, database=None, **kwargs):

with TaskContext(taskname, config, database) as context:
if model_type is None:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

PredictionJobLoop(context, model_type=model_type).map(
create_forecast_task, context, **kwargs
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/optimize_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pathlib import Path

from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType, PipelineType
from openstef.enums import ModelType, PipelineType
from openstef.model.serializer import MLflowSerializer
from openstef.monitoring import teams
from openstef.pipeline.optimize_hyperparameters import optimize_hyperparameters_pipeline
Expand Down Expand Up @@ -124,7 +124,7 @@ def main(config=None, database=None):
)

with TaskContext(taskname, config, database) as context:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

PredictionJobLoop(context, model_type=model_type).map(
optimize_hyperparameters_task, context
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/split_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import openstef.monitoring.teams as monitoring
from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType
from openstef.enums import ModelType
from openstef.settings import Settings
from openstef.tasks.utils.predictionjobloop import PredictionJobLoop
from openstef.tasks.utils.taskcontext import TaskContext
Expand All @@ -51,7 +51,7 @@ def main(config=None, database=None):
)

with TaskContext(taskname, config, database) as context:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

PredictionJobLoop(
context,
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path

from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef.enums import MLModelType, PipelineType
from openstef.enums import ModelType, PipelineType
from openstef.exceptions import (
InputDataOngoingZeroFlatlinerError,
SkipSaveTrainingForecasts,
Expand Down Expand Up @@ -179,7 +179,7 @@ def main(model_type=None, config=None, database=None):
)

if model_type is None:
model_type = [ml.value for ml in MLModelType]
model_type = [ml.value for ml in ModelType]

taskname = Path(__file__).name.replace(".py", "")
datetime_now = datetime.utcnow()
Expand Down
Loading

0 comments on commit 53082af

Please sign in to comment.