diff --git a/CHANGELOG.md b/CHANGELOG.md
index cdd99888fa..9f2f48a3b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 **Fixed**
 - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug when using encoders with `RegressionModel` and series with a non-evenly spaced frequency (e.g. Month Begin). This raised an error during lagged data creation when trying to divide a pd.Timedelta by the ambiguous frequency. [#2034](https://github.com/unit8co/darts/pull/2034) by [Antoine Madrona](https://github.com/madtoinou).
+- Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu).
 
 ### For developers of the library:
 
diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 8312808689..28cc5c8d0f 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -107,6 +107,12 @@ RUNS_FOLDER = "runs"
 
 INIT_MODEL_NAME = "_model.pth.tar"
 
+TORCH_NP_DTYPES = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+}
+
 # pickling a TorchForecastingModel will not save below attributes: the keys specify the
 # attributes to be ignored, and the values are the default values getting assigned upon loading
 TFM_ATTRS_NO_PICKLE = {"model": None, "trainer": None}
@@ -1872,8 +1878,9 @@ def load_weights_from_checkpoint(
         )
 
         # pl_forecasting module saves the train_sample shape, must recreate one
+        np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]]
         mock_train_sample = [
-            np.zeros(sample_shape) if sample_shape else None
+            np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None
             for sample_shape in ckpt["train_sample_shape"]
         ]
         self.train_sample = tuple(mock_train_sample)
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 400899b76a..6a150bbfa3 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1045,6 +1045,29 @@ def test_load_weights(self, tmpdir_fn):
                 f"respectively {retrained_mape} and {original_mape}"
             )
 
+    def test_load_weights_with_float32_dtype(self, tmpdir_fn):
+        ts_float32 = self.series.astype("float32")
+        model_name = "test_model"
+        ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
+        # barebone model
+        model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            n_epochs=1,
+        )
+        model.fit(ts_float32)
+        model.save(ckpt_path)
+        assert model.model._dtype == torch.float32  # type: ignore
+
+        # identical model
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+        )
+        loading_model.load_weights(ckpt_path)
+        loading_model.fit(ts_float32)
+        assert loading_model.model._dtype == torch.float32  # type: ignore
+
     def test_multi_steps_pipeline(self, tmpdir_fn):
         ts_training, ts_val = self.series.split_before(75)
         pretrain_model_name = "pre-train"
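
For context on the patch above, a minimal sketch of the failure mode it addresses. This is illustrative only: the `ckpt` dict and the shapes below are stand-ins, not the actual checkpoint contents Darts saves.

    import numpy as np
    import torch

    # mirrors the TORCH_NP_DTYPES mapping added in torch_forecasting_model.py
    TORCH_NP_DTYPES = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
    }

    # stand-in for a checkpoint produced by a model trained in float32
    ckpt = {"model_dtype": torch.float32, "train_sample_shape": [(4, 1), None]}

    # before the fix: np.zeros() defaults to float64, so the mock train sample
    # recreated at load time no longer matches the float32 weights
    old_sample = [
        np.zeros(shape) if shape else None for shape in ckpt["train_sample_shape"]
    ]
    assert old_sample[0].dtype == np.float64  # precision mismatch

    # after the fix: the mock sample is created with the checkpoint's dtype
    np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]]
    new_sample = [
        np.zeros(shape, dtype=np_dtype) if shape else None
        for shape in ckpt["train_sample_shape"]
    ]
    assert new_sample[0].dtype == np.float32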