Skip to content

Commit

Permalink
Format Python code with Black
Browse files Browse the repository at this point in the history
Signed-off-by: black <[email protected]>
  • Loading branch information
actions-user committed Dec 8, 2023
1 parent 3614df6 commit cdb2a67
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions test/unit/pipeline/test_pipeline_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,48 +686,55 @@ def test_train_pipeline_common_with_missing_custom_horizon(self):
) = train_pipeline_common(
self.pj, self.model_specs, self.train_input, horizons="custom_horizon"
)



@patch("openstef.pipeline.train_model.MLflowSerializer")
@patch("openstef.pipeline.train_model.train_model_pipeline_core")
def test_train_model_pipeline_with_default_train_horizons(self, mock_train_model_pipeline_core, mock_serializer):
def test_train_model_pipeline_with_default_train_horizons(
self, mock_train_model_pipeline_core, mock_serializer
):
# Arrange
mock_train_model_pipeline_core.return_value = 'a', 'b', 'c', 'd'
mock_train_model_pipeline_core.return_value = "a", "b", "c", "d"

# Act
train_model_pipeline(
pj=self.pj,
input_data=self.train_input,
check_old_model_age=False,
mlflow_tracking_uri="./test/unit/trained_models/mlruns",
artifact_folder=None
artifact_folder=None,
)

# Assert
self.pj.train_horizons_minutes == None
assert mock_train_model_pipeline_core.call_args.kwargs["horizons"] == DEFAULT_TRAIN_HORIZONS_HOURS

assert (
mock_train_model_pipeline_core.call_args.kwargs["horizons"]
== DEFAULT_TRAIN_HORIZONS_HOURS
)

@patch("openstef.pipeline.train_model.MLflowSerializer")
@patch("openstef.pipeline.train_model.train_model_pipeline_core")
def test_train_model_pipeline_with_custom_train_horizons(self, mock_train_model_pipeline_core, mock_serializer):
def test_train_model_pipeline_with_custom_train_horizons(
self, mock_train_model_pipeline_core, mock_serializer
):
# Arrange
mock_train_model_pipeline_core.return_value = 'a', 'b', 'c', 'd'
mock_train_model_pipeline_core.return_value = "a", "b", "c", "d"
self.pj.train_horizons_minutes = [1440, 21600]
train_horizons_hours = [24, 360]

# Act
train_model_pipeline(
pj=self.pj,
input_data=self.train_input,
check_old_model_age=False,
mlflow_tracking_uri="./test/unit/trained_models/mlruns",
artifact_folder=None
artifact_folder=None,
)

# Assert
assert mock_train_model_pipeline_core.call_args.kwargs["horizons"] == train_horizons_hours

assert (
mock_train_model_pipeline_core.call_args.kwargs["horizons"]
== train_horizons_hours
)

@patch("openstef.pipeline.train_model.MLflowSerializer")
def test_train_model_pipeline_with_save_train_forecasts(self, mock_serializer):
Expand Down

0 comments on commit cdb2a67

Please sign in to comment.