Skip to content

Commit

Permalink
distance names
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewMiddlehurst committed Dec 5, 2024
1 parent a778d36 commit 1c9f598
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def _path_mask(cost_matrix, path, ax, theme=None): # pragma: no cover
ax.matshow(plot_matrix, cmap=theme)


def _pairwise_path(x, y, metric): # pragma: no cover
pw_matrix = pairwise_distance(x, y, metric=metric)
def _pairwise_path(x, y, method): # pragma: no cover
pw_matrix = pairwise_distance(x, y, method=method)
path = []
for i in range(pw_matrix.shape[0]):
for j in range(pw_matrix.shape[1]):
Expand All @@ -49,7 +49,7 @@ def _pairwise_path(x, y, metric): # pragma: no cover
def _plot_path(
x: np.ndarray,
y: np.ndarray,
metric: str,
method: str,
dist_kwargs: Optional[dict] = None,
title: str = "",
plot_over_pw: bool = False,
Expand All @@ -61,25 +61,25 @@ def _plot_path(
if dist_kwargs is None:
dist_kwargs = {}
try:
path, dist = alignment_path(x, y, metric=metric, **dist_kwargs)
cost_matrix = compute_cost_matrix(x, y, metric=metric, **dist_kwargs)
path, dist = alignment_path(x, y, method=method, **dist_kwargs)
cost_matrix = compute_cost_matrix(x, y, method=method, **dist_kwargs)

if metric == "lcss":
if method == "lcss":
_path = []
for tup in path:
_path.append(tuple(x + 1 for x in tup))
path = _path

if plot_over_pw is True:
if metric == "lcss":
pw = pairwise_distance(x, y, metric="euclidean")
if method == "lcss":
pw = pairwise_distance(x, y, method="euclidean")
cost_matrix = np.zeros_like(cost_matrix)
cost_matrix[1:, 1:] = pw
else:
pw = pairwise_distance(x, y, metric="squared")
pw = pairwise_distance(x, y, method="squared")
cost_matrix = pw
except NotImplementedError:
path, dist, cost_matrix = _pairwise_path(x, y, metric)
path, dist, cost_matrix = _pairwise_path(x, y, method)

plt.figure(1, figsize=(8, 8))
x_size = x.shape[0]
Expand Down Expand Up @@ -119,7 +119,7 @@ def _plot_path(


def _plot_alignment(
x, y, metric, dist_kwargs: Optional[dict] = None, title: str = ""
x, y, method, dist_kwargs: Optional[dict] = None, title: str = ""
): # pragma: no cover
_check_soft_dependencies("matplotlib")

Expand All @@ -128,9 +128,9 @@ def _plot_alignment(
if dist_kwargs is None:
dist_kwargs = {}
try:
path, dist = alignment_path(x, y, metric=metric, **dist_kwargs)
path, dist = alignment_path(x, y, method=method, **dist_kwargs)
except NotImplementedError:
path, dist, cost_matrix = _pairwise_path(x, y, metric)
path, dist, cost_matrix = _pairwise_path(x, y, method)

plt.figure(1, figsize=(8, 8))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


# used for dtw and wdtw primarily
def _tune_window(metric, train_X, n_clusters): # pragma: no cover
def _tune_window(method, train_X, n_clusters): # pragma: no cover
best_w = 0
best_score = sys.float_info.max
for w in np.arange(0.0, 0.2, 0.01):
cls = TimeSeriesKMeans(
metric=metric, distance_params={"window": w}, n_clusters=n_clusters
distance=method, distance_params={"window": w}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand All @@ -35,7 +35,7 @@ def _tune_msm(train_X, n_clusters): # pragma: no cover
best_score = sys.float_info.max
for c in np.arange(0.0, 5.0, 0.25):
cls = TimeSeriesKMeans(
metric="msm", distance_params={"c": c}, n_clusters=n_clusters
distance="msm", distance_params={"c": c}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand All @@ -57,7 +57,7 @@ def _tune_wdtw(train_X, n_clusters): # pragma: no cover
best_score = sys.float_info.max
for g in np.arange(0.0, 1.0, 0.05):
cls = TimeSeriesKMeans(
metric="wdtw", distance_params={"g": g}, n_clusters=n_clusters
distance="wdtw", distance_params={"g": g}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand All @@ -81,7 +81,7 @@ def _tune_twe(train_X, n_clusters): # pragma: no cover
for nu in np.arange(0.0, 1.0, 0.25):
for lam in np.arange(0.0, 1.0, 0.2):
cls = TimeSeriesKMeans(
metric="twe",
distance="twe",
distance_params={"nu": nu, "lmbda": lam},
n_clusters=n_clusters,
)
Expand Down Expand Up @@ -110,7 +110,7 @@ def _tune_erp(train_X, n_clusters): # pragma: no cover
best_score = sys.float_info.max
for g in np.arange(0.0, 2.0, 0.2):
cls = TimeSeriesKMeans(
metric="erp", distance_params={"g": g}, n_clusters=n_clusters
distance="erp", distance_params={"g": g}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand All @@ -132,7 +132,7 @@ def _tune_edr(train_X, n_clusters): # pragma: no cover
best_score = sys.float_info.max
for e in np.arange(0.0, 0.2, 0.01):
cls = TimeSeriesKMeans(
metric="edr", distance_params={"epsilon": e}, n_clusters=n_clusters
distance="edr", distance_params={"epsilon": e}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand All @@ -156,7 +156,7 @@ def _tune_lcss(train_X, n_clusters): # pragma: no cover
best_score = sys.float_info.max
for e in np.arange(0.0, 0.2, 0.01):
cls = TimeSeriesKMeans(
metric="lcss", distance_params={"epsilon": e}, n_clusters=n_clusters
distance="lcss", distance_params={"epsilon": e}, n_clusters=n_clusters
)
cls.fit(train_X)
preds = cls.predict(train_X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@
"# Find the distance between time series\n",
"from aeon.distances import distance\n",
"\n",
"dtw = distance(x, y, metric=\"dtw\")\n",
"dtw_windowed = distance(x, y, metric=\"dtw\", window=0.2)\n",
"edr = distance(x, y, metric=\"edr\")\n",
"erp = distance(x, y, metric=\"erp\", epsilon=1.0)\n",
"lcss = distance(x, y, metric=\"lcss\")\n",
"msm = distance(x, y, metric=\"msm\")\n",
"twe = distance(x, y, metric=\"twe\")\n",
"wdtw = distance(x, y, metric=\"wdtw\")\n",
"wdtw_g2 = distance(x, y, metric=\"wdtw\", g=0.2)\n",
"wdtw_g3 = distance(x, y, metric=\"wdtw\", g=0.3)"
"dtw = distance(x, y, method=\"dtw\")\n",
"dtw_windowed = distance(x, y, method=\"dtw\", window=0.2)\n",
"edr = distance(x, y, method=\"edr\")\n",
"erp = distance(x, y, method=\"erp\", epsilon=1.0)\n",
"lcss = distance(x, y, method=\"lcss\")\n",
"msm = distance(x, y, method=\"msm\")\n",
"twe = distance(x, y, method=\"twe\")\n",
"wdtw = distance(x, y, method=\"wdtw\")\n",
"wdtw_g2 = distance(x, y, method=\"wdtw\", g=0.2)\n",
"wdtw_g3 = distance(x, y, method=\"wdtw\", g=0.3)"
],
"metadata": {
"collapsed": false,
Expand All @@ -102,16 +102,16 @@
"# Generate the path for each distance\n",
"from aeon.distances import alignment_path\n",
"\n",
"dtw_path = alignment_path(x, y, metric=\"dtw\")\n",
"dtw_path_windowed = alignment_path(x, y, metric=\"dtw\", window=0.2)\n",
"edr_path = alignment_path(x, y, metric=\"edr\")\n",
"erp_path = alignment_path(x, y, metric=\"erp\", epsilon=1.0)\n",
"lcss_path = alignment_path(x, y, metric=\"lcss\")\n",
"msm_path = alignment_path(x, y, metric=\"msm\")\n",
"twe_path = alignment_path(x, y, metric=\"twe\")\n",
"wdtw_path = alignment_path(x, y, metric=\"wdtw\")\n",
"wdtw_path_g2 = alignment_path(x, y, metric=\"wdtw\", g=0.2)\n",
"wdtw_path_g3 = alignment_path(x, y, metric=\"wdtw\", g=0.3)"
"dtw_path = alignment_path(x, y, method=\"dtw\")\n",
"dtw_path_windowed = alignment_path(x, y, method=\"dtw\", window=0.2)\n",
"edr_path = alignment_path(x, y, method=\"edr\")\n",
"erp_path = alignment_path(x, y, method=\"erp\", epsilon=1.0)\n",
"lcss_path = alignment_path(x, y, method=\"lcss\")\n",
"msm_path = alignment_path(x, y, method=\"msm\")\n",
"twe_path = alignment_path(x, y, method=\"twe\")\n",
"wdtw_path = alignment_path(x, y, method=\"wdtw\")\n",
"wdtw_path_g2 = alignment_path(x, y, method=\"wdtw\", g=0.2)\n",
"wdtw_path_g3 = alignment_path(x, y, method=\"wdtw\", g=0.3)"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -152,9 +152,9 @@
" _plot_path,\n",
")\n",
"\n",
"plt_path = _plot_path(x, y, metric=\"dtw\", title=\"dtw path window 0.2\")\n",
"plt_path = _plot_path(x, y, method=\"dtw\", title=\"dtw path window 0.2\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"dtw\", title=\"dtw alignment window 0.2\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"dtw\", title=\"dtw alignment window 0.2\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -191,11 +191,11 @@
],
"source": [
"plt_path = _plot_path(\n",
" x, y, metric=\"dtw\", title=\"dtw path window 0.2\", dist_kwargs={\"window\": 0.2}\n",
" x, y, method=\"dtw\", title=\"dtw path window 0.2\", dist_kwargs={\"window\": 0.2}\n",
")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(\n",
" x, y, metric=\"dtw\", title=\"dtw alignment window 0.2\", dist_kwargs={\"window\": 0.2}\n",
" x, y, method=\"dtw\", title=\"dtw alignment window 0.2\", dist_kwargs={\"window\": 0.2}\n",
")\n",
"plt_alignment.show()"
],
Expand Down Expand Up @@ -233,11 +233,11 @@
],
"source": [
"plt_path = _plot_path(\n",
" x, y, metric=\"edr\", title=\"edr path\", dist_kwargs={\"epsilon\": 1.0}\n",
" x, y, method=\"edr\", title=\"edr path\", dist_kwargs={\"epsilon\": 1.0}\n",
")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(\n",
" x, y, metric=\"edr\", title=\"edr alignment\", dist_kwargs={\"epsilon\": 0.2}\n",
" x, y, method=\"edr\", title=\"edr alignment\", dist_kwargs={\"epsilon\": 0.2}\n",
")\n",
"plt_alignment.show()"
],
Expand Down Expand Up @@ -274,9 +274,9 @@
}
],
"source": [
"plt_path = _plot_path(x, y, metric=\"erp\", title=\"erp path\")\n",
"plt_path = _plot_path(x, y, method=\"erp\", title=\"erp path\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"erp\", title=\"erp alignment\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"erp\", title=\"erp alignment\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -312,9 +312,9 @@
}
],
"source": [
"plt_path = _plot_path(x, y, metric=\"lcss\", title=\"lcss path\")\n",
"plt_path = _plot_path(x, y, method=\"lcss\", title=\"lcss path\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"lcss\", title=\"lcss alignment\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"lcss\", title=\"lcss alignment\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -350,9 +350,9 @@
}
],
"source": [
"plt_path = _plot_path(x, y, metric=\"msm\", title=\"msm path\")\n",
"plt_path = _plot_path(x, y, method=\"msm\", title=\"msm path\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"msm\", title=\"msm alignment\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"msm\", title=\"msm alignment\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -388,9 +388,9 @@
}
],
"source": [
"plt_path = _plot_path(x, y, metric=\"twe\", title=\"twe path\")\n",
"plt_path = _plot_path(x, y, method=\"twe\", title=\"twe path\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"twe\", title=\"twe alignment\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"twe\", title=\"twe alignment\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -426,9 +426,9 @@
}
],
"source": [
"plt_path = _plot_path(x, y, metric=\"wdtw\", title=\"wdtw path\")\n",
"plt_path = _plot_path(x, y, method=\"wdtw\", title=\"wdtw path\")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(x, y, metric=\"wdtw\", title=\"wdtw alignment\")\n",
"plt_alignment = _plot_alignment(x, y, method=\"wdtw\", title=\"wdtw alignment\")\n",
"plt_alignment.show()"
],
"metadata": {
Expand Down Expand Up @@ -465,11 +465,11 @@
],
"source": [
"plt_path = _plot_path(\n",
" x, y, metric=\"wdtw\", title=\"wdtw path g 0.2\", dist_kwargs={\"g\": 0.2}\n",
" x, y, method=\"wdtw\", title=\"wdtw path g 0.2\", dist_kwargs={\"g\": 0.2}\n",
")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(\n",
" x, y, metric=\"wdtw\", title=\"wdtw alignment g 0.2\", dist_kwargs={\"g\": 0.2}\n",
" x, y, method=\"wdtw\", title=\"wdtw alignment g 0.2\", dist_kwargs={\"g\": 0.2}\n",
")\n",
"plt_alignment.show()"
],
Expand Down Expand Up @@ -507,11 +507,11 @@
],
"source": [
"plt_path = _plot_path(\n",
" x, y, metric=\"wdtw\", title=\"wdtw path g 0.3\", dist_kwargs={\"g\": 0.3}\n",
" x, y, method=\"wdtw\", title=\"wdtw path g 0.3\", dist_kwargs={\"g\": 0.3}\n",
")\n",
"plt_path.show()\n",
"plt_alignment = _plot_alignment(\n",
" x, y, metric=\"wdtw\", title=\"wdtw alignment g 0.3\", dist_kwargs={\"g\": 0.3}\n",
" x, y, method=\"wdtw\", title=\"wdtw alignment g 0.3\", dist_kwargs={\"g\": 0.3}\n",
")\n",
"plt_alignment.show()"
],
Expand Down

0 comments on commit 1c9f598

Please sign in to comment.