Docs/general improvements (#1904)
* update dtw example and add dtw to rendered documentation

* make all windows inherit from Window

* clean up windows

* improve dtw documentation

* improved forecasting model module documentation

* update models and add model links in covariates user guide

* add model links to README

* update changelog

* fix typo in dtw example notebook

* remove outdated lines from tide model from before probabilistic support

* apply suggestions from PR review

* update readme model table

* Feat/fit predict encodings (#1925)

* added encode_train_inference to encoders

* added generate_fit_predict_encodings to ForecastingModel

* simplify TransferrableFut..Model.generate_predict_encodings

* update changelog

* Apply suggestions from code review

Co-authored-by: madtoinou <[email protected]>

* apply suggestions from PR review part 2

---------

Co-authored-by: madtoinou <[email protected]>

* update readme

* apply suggestions from code review and improve README.md

* Update README.md

Co-authored-by: madtoinou <[email protected]>

---------

Co-authored-by: madtoinou <[email protected]>
dennisbader and madtoinou authored Aug 4, 2023
1 parent 1af057a commit 8b88b0d
Showing 10 changed files with 282 additions and 164 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Improvements to `Explainability` module:
- 🚀🚀 New forecasting model explainer: `TFTExplainer` for `TFTModel`. You can now access and visualize the trained model's feature importances and self attention. [#1392](https://github.com/unit8co/darts/issues/1392) by [Sebastian Cattes](https://github.com/Cattes) and [Dennis Bader](https://github.com/dennisbader).
- Added static covariates support to `ShapExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Anne de Vries](https://github.com/anne-devries) and [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation [#1904](https://github.com/unit8co/darts/pull/1904) by [Dennis Bader](https://github.com/dennisbader):
- made model sections in README.md, covariates user guide and forecasting model API Reference more user friendly by adding model links and reorganizing them into model categories.
- added the Dynamic Time Warping (DTW) module and improved its appearance.
- Other improvements:
- Improved static covariates column naming when using `StaticCovariatesTransformer` with a `sklearn.preprocessing.OneHotEncoder`. [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries).
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96).
85 changes: 47 additions & 38 deletions README.md

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions darts/dataprocessing/dtw/__init__.py
@@ -1,3 +1,8 @@
"""
Dynamic Time Warping (DTW)
--------------------------
"""

from .cost_matrix import CostMatrix
from .dtw import DTWAlignment, dtw
from .window import CRWindow, Itakura, NoWindow, SakoeChiba, Window
79 changes: 52 additions & 27 deletions darts/dataprocessing/dtw/dtw.py
@@ -1,5 +1,10 @@
"""
Dynamic Time Warping (DTW)
--------------------------
"""

import copy
from typing import Callable, Union
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd
@@ -21,7 +26,7 @@
# CORE ALGORITHM
def _dtw_cost_matrix(
x: np.ndarray, y: np.ndarray, dist: DistanceFunc, window: Window
) -> np.ndarray:
) -> CostMatrix:

dtw = CostMatrix._from_window(window)

@@ -138,16 +143,40 @@ def _fast_dtw(
return cost


def _default_distance_multi(x_values: np.ndarray, y_values: np.ndarray):
return np.sum(np.abs(x_values - y_values))


def _default_distance_uni(x_value: float, y_value: float):
return abs(x_value - y_value)


# Public API Functions
class DTWAlignment:
"""
Dynamic Time Warping (DTW) Alignment.
Attributes
----------
n
The length of `series1`
m
The length of `series2`
series1
A `TimeSeries` to align with `series2`.
series2
A `TimeSeries` to align with `series1`.
cost
The `CostMatrix` for DTW.
"""

n: int
m: int
series1: TimeSeries
series2: TimeSeries
cost: CostMatrix

def __init__(self, series1: TimeSeries, series2: TimeSeries, cost: CostMatrix):

self.n = len(series1)
self.m = len(series2)
self.series1 = series1
@@ -157,7 +186,8 @@ def __init__(self, series1: TimeSeries, series2: TimeSeries, cost: CostMatrix):
from ._plot import plot, plot_alignment

def path(self) -> np.ndarray:
"""
"""Gives the index paths from `series1` to `series2`.
Returns
-------
np.ndarray of shape `(len(path), 2)`
@@ -172,7 +202,8 @@ def path(self) -> np.ndarray:
return self._path

def distance(self) -> float:
"""
"""Gives the total distance between pair-wise elements in the two series after warping.
Returns
-------
float
@@ -181,7 +212,8 @@ def distance(self) -> float:
return self.cost[(self.n, self.m)]

def mean_distance(self) -> float:
"""
"""Gives the mean distance between pair-wise elements in the two series after warping.
Returns
-------
float
@@ -195,9 +227,8 @@ def mean_distance(self) -> float:
return self._mean_distance

def warped(self) -> (TimeSeries, TimeSeries):
"""
Warps the two time series according to the warp path returned by .path(), which minimizes
the pair-wise distance.
"""Warps the two time series according to the warp path returned by `DTWAlignment.path()`, which minimizes the
pair-wise distance.
This will bring two time series that are out-of-phase back into phase.
Returns
@@ -254,24 +285,16 @@ def warped(self) -> (TimeSeries, TimeSeries):
)


def default_distance_multi(x_values: np.ndarray, y_values: np.ndarray):
return np.sum(np.abs(x_values - y_values))


def default_distance_uni(x_value: float, y_value: float):
return abs(x_value - y_value)


def dtw(
series1: TimeSeries,
series2: TimeSeries,
window: Window = NoWindow(),
window: Optional[Window] = None,
distance: Union[DistanceFunc, None] = None,
multi_grid_radius: int = -1,
) -> DTWAlignment:
"""
Determines the optimal alignment between two time series series1 and series2,
according to the Dynamic Time Warping algorithm.
Determines the optimal alignment between two time series `series1` and `series2`, according to the Dynamic Time
Warping algorithm.
The alignment minimizes the distance between pair-wise elements after warping.
All elements in the two series are matched and are in strictly monotonically increasing order.
Considers only the values in the series, ignoring the time axis.
@@ -282,24 +305,24 @@ def dtw(
Parameters
----------
series1
`TimeSeries`
A `TimeSeries` to align with `series2`.
series2
A `TimeSeries`
A `TimeSeries` to align with `series1`.
window
Used to constrain the search for the optimal alignment: see SakoeChiba and Itakura.
Default considers all possible alignments.
Optionally, a `Window` used to constrain the search for the optimal alignment: see `SakoeChiba` and `Itakura`.
Default considers all possible alignments (`NoWindow`).
distance
Function taking as input either two `floats` for univariate series or two `np.ndarray`,
and returning the distance between them.
Defaults to the abs difference for univariate-data and the
sum of the abs difference for multi-variate series.
multi_grid_radius
Default radius of -1 results in an exact evaluation of the dynamic time warping algorithm.
Default radius of `-1` results in an exact evaluation of the dynamic time warping algorithm.
Without constraints DTW runs in O(nxm) time where n,m are the size of the series.
Exact evaluation with no constraints, will result in a performance warning on large datasets.
Setting multi_grid_radius to a value other than -1, will enable the approximate multi-grid solver,
Setting `multi_grid_radius` to a value other than `-1`, will enable the approximate multi-grid solver,
which executes in linear time, vs quadratic time for exact evaluation.
Increasing radius trades solution accuracy for performance.
@@ -308,6 +331,8 @@ def dtw(
DTWAlignment
Helper object for getting warp path, mean_distance, distance and warped time series
"""
if window is None:
window = NoWindow()

if (
multi_grid_radius == -1
@@ -328,7 +353,7 @@ def dtw(
logger,
)

distance = default_distance_uni if both_univariate else default_distance_multi
distance = _default_distance_uni if both_univariate else _default_distance_multi

if both_univariate:
values_x = series1.univariate_values(copy=False)
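The `dtw` docstring in this diff describes an exact, unconstrained evaluation that runs in O(n x m) time and defaults to the absolute difference for univariate series. The following is a minimal pure-Python sketch of that idea, not the darts implementation: the names `abs_distance`, `dtw_cost_matrix` and `dtw_path` are hypothetical, there is no window or multi-grid solver, and both inputs are assumed to be non-empty lists of floats. The mean distance is assumed here to be the total distance divided by the path length.

```python
from typing import Callable, List, Sequence, Tuple


def abs_distance(x: float, y: float) -> float:
    """Univariate distance: absolute difference (hypothetical helper)."""
    return abs(x - y)


def dtw_cost_matrix(
    xs: Sequence[float],
    ys: Sequence[float],
    dist: Callable[[float, float], float] = abs_distance,
) -> List[List[float]]:
    """Fill the (n + 1) x (m + 1) accumulated-cost table in O(n x m) time."""
    n, m = len(xs), len(ys)
    inf = float("inf")
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0  # empty prefixes align at zero cost
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            step = dist(xs[i - 1], ys[j - 1])
            # Extend the cheapest of the three allowed predecessor cells.
            cost[i][j] = step + min(
                cost[i - 1][j - 1], cost[i - 1][j], cost[i][j - 1]
            )
    return cost


def dtw_path(cost: List[List[float]]) -> List[Tuple[int, int]]:
    """Backtrack from cell (n, m) to recover the index pairs of the warp path."""
    i, j = len(cost) - 1, len(cost[0]) - 1
    path = []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        # Move to the cheapest predecessor; border cells are infinite,
        # so the walk always ends at cell (0, 0).
        _, i, j = min(
            (cost[i - 1][j - 1], i - 1, j - 1),
            (cost[i - 1][j], i - 1, j),
            (cost[i][j - 1], i, j - 1),
        )
    path.reverse()
    return path


xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 0.0, 1.0, 2.0, 3.0]
cost = dtw_cost_matrix(xs, ys)
path = dtw_path(cost)
total_distance = cost[len(xs)][len(ys)]  # same cell as cost[(n, m)] in the diff
mean_distance = total_distance / len(path)
```

Here `ys` is `xs` with its first value repeated, so the warp path matches the first element of `xs` twice and the total distance is zero. Reading the result from cell `(n, m)` mirrors `DTWAlignment.distance()` in the diff above.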
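The signature change from `window: Window = NoWindow()` to `window: Optional[Window] = None`, paired with the `if window is None` check, follows a common Python idiom. The commit does not state its motivation; one plausible reason is that a default argument expression is evaluated once, when the function is defined, so every call without an argument would share the same instance. The sketch below illustrates only that language behaviour. `NoWindow` here is a hypothetical stand-in class, not the darts one, and both function names are made up for the example.

```python
from typing import Optional


class NoWindow:
    """Hypothetical stand-in for a window object; not the darts class."""


def with_instance_default(window: NoWindow = NoWindow()) -> NoWindow:
    # The default object was created once, at function definition time.
    return window


def with_none_default(window: Optional[NoWindow] = None) -> NoWindow:
    # A fresh object is created on each call that omits the argument.
    if window is None:
        window = NoWindow()
    return window


shared_a = with_instance_default()
shared_b = with_instance_default()
fresh_a = with_none_default()
fresh_b = with_none_default()

print(shared_a is shared_b)  # both names refer to the single default object
print(fresh_a is fresh_b)    # two separate objects
```

With the `None` sentinel, any state a window object accumulates during one call cannot leak into the next call.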
