Skip to content

Commit

Permalink
fix: Pass kwargs to forecasting tasks (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
clara-de-smet authored Aug 28, 2024
1 parent 0bf8001 commit c38f226
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions openstef/tasks/create_basecase_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def create_basecase_forecast_task(
context.database.write_forecast(basecase_forecast, t_ahead_series=True)


def main(config: object = None, database: object = None):
def main(config: object = None, database: object = None, **kwargs):
taskname = Path(__file__).name.replace(".py", "")

if database is None or config is None:
Expand All @@ -110,7 +110,7 @@ def main(config: object = None, database: object = None):
model_type = ["xgb", "xgb_quantile", "lgb"]

PredictionJobLoop(context, model_type=model_type).map(
create_basecase_forecast_task, context
create_basecase_forecast_task, context, **kwargs
)


Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/create_components_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def create_components_forecast_task(
)


def main(config: object = None, database: object = None):
def main(config: object = None, database: object = None, **kwargs):
taskname = Path(__file__).name.replace(".py", "")

if database is None or config is None:
Expand All @@ -155,7 +155,7 @@ def main(config: object = None, database: object = None):
PredictionJobLoop(
context,
model_type=model_type,
).map(create_components_forecast_task, context)
).map(create_components_forecast_task, context, **kwargs)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/create_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def create_forecast_task(
context.database.write_forecast(forecast, t_ahead_series=True)


def main(model_type=None, config=None, database=None):
def main(model_type=None, config=None, database=None, **kwargs):
taskname = Path(__file__).name.replace(".py", "")

if database is None or config is None:
Expand All @@ -132,7 +132,7 @@ def main(model_type=None, config=None, database=None):
model_type = [ml.value for ml in MLModelType]

PredictionJobLoop(context, model_type=model_type).map(
create_forecast_task, context
create_forecast_task, context, **kwargs
)


Expand Down
4 changes: 2 additions & 2 deletions openstef/tasks/create_solar_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def fides(data: pd.DataFrame, all_forecasts: bool = False):
return forecast


def main(config=None, database=None):
def main(config=None, database=None, **kwargs):
taskname = Path(__file__).name.replace(".py", "")

if database is None or config is None:
Expand Down Expand Up @@ -245,7 +245,7 @@ def main(config=None, database=None):
)

PredictionJobLoop(context, prediction_jobs=prediction_jobs).map(
make_solar_prediction_pj, context
make_solar_prediction_pj, context, kwargs=kwargs
)


Expand Down

0 comments on commit c38f226

Please sign in to comment.