forked from unit8co/darts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improvement/statsforecastets: make sf_ets probabilistic + add future_…
…covariate support for sf_ets + add AutoTheta (unit8co#1476) * StatsForecastETS now is probabilistic in the same way as StatsForecastAutoARIMA * include future covariates in sf_ets * sf_ets with future_covariates works.. probably it is underestimating the uncertainty because it doesn't take into account the uncertainty of the coef esimation of the OLS * Create separate file for StatsForecast models and extract some functions. * Added AutoTheta from the StatsForecast package. * Deleted sf_auto_arima.py and sf_ets.py, because the code is now included in sf_models.py. * Update darts/models/forecasting/sf_models.py Co-authored-by: Julien Herzen <[email protected]> * Update darts/models/forecasting/sf_models.py Co-authored-by: Julien Herzen <[email protected]> * Moved all statsforecast models to their own .py file. Added some comments explaining the handling of future covariates by StatsForecastETS. Included StatsForecastTheta in the tests. Moved the utility functions that the statsforecast models share to a singly .py file. Added the CES model which is supposed to be probabilistic, but that doesn't work yet eventhough it is supposed to be included in statsforecast 1.4.0. Trying to figure out why it isn't working. Removed sf_models.py. * Beginning of test for fit on residuals for statsforecast ets. * - AutoCES not probablisitc anymore, because that is not yet released in statsforecast 1.4.0 - changed AutoETS to SFAutoETS - added models to the base tests - wrote two units tests for future covariates use for sf_ets * - AutoCES not probablisitc anymore, because that is not yet released in statsforecast 1.4.0 - changed AutoETS to SFAutoETS - added models to the base tests - wrote two units tests for future covariates use for sf_ets * Changed StatsForecastETS to StatsForecastAutoETS. --------- Co-authored-by: Julien Herzen <[email protected]> Co-authored-by: Julien Herzen <[email protected]>
- Loading branch information
1 parent
31528a4
commit e1c8d34
Showing
9 changed files
with
346 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
StatsForecast utils | ||
----------- | ||
""" | ||
|
||
import numpy as np | ||
|
||
# In a normal distribution, 68.27 percentage of values lie within one standard deviation of the mean | ||
one_sigma_rule = 68.27 | ||
|
||
|
||
def create_normal_samples( | ||
mu: float, | ||
std: float, | ||
num_samples: int, | ||
n: int, | ||
) -> np.array: | ||
"""Generate samples assuming a Normal distribution.""" | ||
samples = np.random.normal(loc=mu, scale=std, size=(num_samples, n)).T | ||
samples = np.expand_dims(samples, axis=1) | ||
return samples | ||
|
||
|
||
def unpack_sf_dict( | ||
forecast_dict: dict, | ||
): | ||
"""Unpack the dictionary that is returned by the StatsForecast 'predict()' method.""" | ||
mu = forecast_dict["mean"] | ||
std = forecast_dict[f"hi-{one_sigma_rule}"] - mu | ||
return mu, std |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
StatsForecastAutoCES | ||
----------- | ||
""" | ||
|
||
from statsforecast.models import AutoCES as SFAutoCES | ||
|
||
from darts import TimeSeries | ||
from darts.models.forecasting.forecasting_model import LocalForecastingModel | ||
|
||
|
||
class StatsForecastAutoCES(LocalForecastingModel): | ||
def __init__(self, *autoces_args, **autoces_kwargs): | ||
"""Auto-CES based on `Statsforecasts package | ||
<https://github.com/Nixtla/statsforecast>`_. | ||
Automatically selects the best Complex Exponential Smoothing model using an information criterion. | ||
<https://onlinelibrary.wiley.com/doi/full/10.1002/nav.22074> | ||
We refer to the `statsforecast AutoCES documentation | ||
<https://nixtla.github.io/statsforecast/models.html#autoces>`_ | ||
for the documentation of the arguments. | ||
Parameters | ||
---------- | ||
autoces_args | ||
Positional arguments for ``statsforecasts.models.AutoCES``. | ||
autoces_kwargs | ||
Keyword arguments for ``statsforecasts.models.AutoCES``. | ||
.. | ||
Examples | ||
-------- | ||
>>> from darts.models import StatsForecastAutoCES | ||
>>> from darts.datasets import AirPassengersDataset | ||
>>> series = AirPassengersDataset().load() | ||
>>> model = StatsForecastAutoCES(season_length=12) | ||
>>> model.fit(series[:-36]) | ||
>>> pred = model.predict(36, num_samples=100) | ||
""" | ||
super().__init__() | ||
self.model = SFAutoCES(*autoces_args, **autoces_kwargs) | ||
|
||
def __str__(self): | ||
return "Auto-CES-Statsforecasts" | ||
|
||
def fit(self, series: TimeSeries): | ||
super().fit(series) | ||
self._assert_univariate(series) | ||
series = self.training_series | ||
self.model.fit( | ||
series.values(copy=False).flatten(), | ||
) | ||
return self | ||
|
||
def predict( | ||
self, | ||
n: int, | ||
num_samples: int = 1, | ||
verbose: bool = False, | ||
): | ||
super().predict(n, num_samples) | ||
forecast_dict = self.model.predict( | ||
h=n, | ||
) | ||
|
||
mu = forecast_dict["mean"] | ||
|
||
return self._build_forecast_series(mu) | ||
|
||
@property | ||
def min_train_series_length(self) -> int: | ||
return 10 | ||
|
||
def _supports_range_index(self) -> bool: | ||
return True | ||
|
||
def _is_probabilistic(self) -> bool: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.