-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #120 from AlpacaDB/feature/ds3289-marketcloseaware
Add 3 New Ternary-Class Labelizers
- Loading branch information
Showing
6 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pandas as pd | ||
|
||
from backlight.datasource.marketdata import MarketData | ||
from backlight.labelizer.common import LabelType, TernaryDirection | ||
from backlight.labelizer.labelizer import Label | ||
from backlight.labelizer.ternary.static_neutral import StaticNeutralLabelizer | ||
from backlight.labelizer.ternary.dynamic_neutral import ( | ||
MarketCloseAwareDynamicNeutralLabelizer, | ||
) | ||
|
||
|
||
class HybridNeutralLabelizer(
    StaticNeutralLabelizer, MarketCloseAwareDynamicNeutralLabelizer
):
    """Ternary labelizer whose neutral range blends two parent ranges.

    The neutral range is ``alpha * static + (1 - alpha) * dynamic``, where the
    components come from ``_calculate_static_neutral_range`` and
    ``_calculate_dynamic_neutral_range`` inherited from the parent classes.

    Args:
        alpha (float): 0 <= x <= 1, weight given to the static neutral range
        lookahead (str): Lookahead period
        Plus whatever the two parent labelizers require in ``validate_params``.
    """

    def __init__(self, **kwargs: str) -> None:
        super().__init__(**kwargs)
        # The base initializer is not visible from this module, so the
        # explicit validation from the original implementation is kept.
        self.validate_params()

    def validate_params(self) -> None:
        """Validate the parameters of both parents, then ``alpha``.

        Raises:
            AssertionError: If a required parameter is missing or ``alpha``
                lies outside ``[0, 1]``.
        """
        # Each parent is called by name. The previous two-argument
        # ``super(MarketCloseAwareDynamicNeutralLabelizer, self)`` started the
        # MRO lookup *after* that class, which would bypass any
        # ``validate_params`` it defines itself.
        StaticNeutralLabelizer.validate_params(self)
        MarketCloseAwareDynamicNeutralLabelizer.validate_params(self)
        assert "alpha" in self._params
        assert 0 <= float(self._params["alpha"]) <= 1

    def _calculate_hybrid_neutral_range(self, diff_abs: pd.Series) -> pd.Series:
        """Return the alpha-weighted mix of the static and dynamic ranges.

        Args:
            diff_abs: Absolute price differences, passed unchanged to both
                parent range calculations.
        """
        snr = self._calculate_static_neutral_range(diff_abs)
        dnr = self._calculate_dynamic_neutral_range(diff_abs)
        return self.alpha * snr + (1 - self.alpha) * dnr

    def create(self, mkt: MarketData) -> pd.DataFrame:
        """Build ternary labels from the mid price of ``mkt``.

        Args:
            mkt: Market data exposing a ``mid`` series.

        Returns:
            A ``Label`` with columns ``label_diff``, ``label`` and
            ``neutral_range``, and ``label_type`` set to ``LabelType.TERNARY``.
        """
        mid = mkt.mid.copy()
        lookahead = self._params["lookahead"]
        # Shifting the index back by ``lookahead`` aligns each timestamp with
        # the mid price observed ``lookahead`` later.
        future_price = mid.shift(freq="-{}".format(lookahead))
        diff = (future_price - mid).reindex(mid.index)
        diff_abs = diff.abs()
        neutral_range = self._calculate_hybrid_neutral_range(diff_abs)

        df = mid.to_frame("mid")
        df.loc[:, "label_diff"] = diff
        df.loc[:, "neutral_range"] = neutral_range
        df.loc[df.label_diff > 0, "label"] = TernaryDirection.UP.value
        df.loc[df.label_diff < 0, "label"] = TernaryDirection.DOWN.value
        # Applied last so that small moves override the UP/DOWN assignment.
        df.loc[diff_abs < neutral_range, "label"] = TernaryDirection.NEUTRAL.value

        df = Label(df[["label_diff", "label", "neutral_range"]])
        df.label_type = LabelType.TERNARY
        return df

    @property
    def alpha(self) -> float:
        """Weight of the static neutral range, read from the ``alpha`` param."""
        return float(self._params["alpha"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import pandas as pd | ||
import numpy as np | ||
|
||
from backlight.datasource.marketdata import MarketData | ||
from backlight.labelizer.common import LabelType, TernaryDirection | ||
from backlight.labelizer.labelizer import Labelizer, Label | ||
|
||
|
||
class StaticNeutralLabelizer(Labelizer):
    """Generates session-aware static labels.

    Args:
        lookahead (str): Lookahead period
        session_splits (list[datetime.time]): America/New_York local times at
            which sessions are split
        neutral_ratio (float): 0 < x < 1, Percentage of NEUTRAL labels
        window_start (str): Start date for lookback window
        window_end (str): End date for lookback window
        neutral_hard_limit (float): The minimum diff to label UP/DOWN
    """

    def validate_params(self) -> None:
        """Assert that every required parameter is present.

        Raises:
            AssertionError: If a parameter is missing or ``session_splits``
                is empty.
        """
        assert "lookahead" in self._params
        assert "session_splits" in self._params
        assert len(self._params["session_splits"])
        assert "neutral_ratio" in self._params
        assert "window_start" in self._params
        assert "window_end" in self._params
        assert "neutral_hard_limit" in self._params

    def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series:
        """Compute one neutral range per session from the lookback window.

        For every session delimited by ``session_splits`` the
        ``neutral_ratio`` quantile of ``diff_abs`` is taken over the rows that
        fall inside the lookback window and pass the time-of-week filter, and
        that single value is assigned to all rows of the session.

        Args:
            diff_abs: Absolute price differences on a tz-aware datetime index.

        Returns:
            Series aligned with ``diff_abs``, floored at
            ``neutral_hard_limit``.
        """
        df = pd.DataFrame(diff_abs.values, index=diff_abs.index, columns=["diff"])
        df.loc[:, "nyk_time"] = df.index.tz_convert("America/New_York")
        df.loc[:, "res"] = np.nan

        nyk = df.nyk_time.dt
        hour = nyk.hour
        dayofweek = nyk.dayofweek  # Monday=0 ... Sunday=6
        time_of_day = nyk.time

        # Rows used to estimate the quantile: inside the lookback window and
        # outside the excluded New York hours (presumably the daily close and
        # the weekend -- confirm against the data source).
        mask = (
            (df.index >= self._params["window_start"])
            & (df.index < self._params["window_end"])
            # Sunday up to and including the 17:00 hour
            & ~((hour <= 17) & (dayofweek == 6))
            # hours 16 and 17 on every day
            & ((hour < 16) | (hour > 17))
            # Friday from 16:00 onwards
            & ~((hour >= 16) & (dayofweek == 4))
            # Saturday
            & (dayofweek != 5)
        )

        splits = sorted(self._params["session_splits"])
        # Pair every split with the next one; the last pairs with the first.
        shifted_splits = splits[1:] + splits[:1]

        for s, t in zip(splits, shifted_splits):
            if s >= t:
                # Session wraps past midnight (or there is a single split).
                scope = (time_of_day >= s) | (time_of_day < t)
            else:
                scope = (time_of_day >= s) & (time_of_day < t)
            df.loc[scope, "res"] = df.loc[scope & mask, "diff"].quantile(
                self.neutral_ratio
            )

        # Floor the range; NaN entries compare False and stay NaN.
        df.loc[df.res < self.neutral_hard_limit, "res"] = self.neutral_hard_limit

        return df.res

    def create(self, mkt: MarketData) -> pd.DataFrame:
        """Build ternary labels from the mid price of ``mkt``.

        Args:
            mkt: Market data exposing a ``mid`` series.

        Returns:
            A ``Label`` with columns ``label_diff``, ``label`` and
            ``neutral_range``, and ``label_type`` set to ``LabelType.TERNARY``.
        """
        mid = mkt.mid.copy()
        # Shifting the index back by ``lookahead`` aligns each timestamp with
        # the mid price observed ``lookahead`` later.
        future_price = mid.shift(freq="-{}".format(self._params["lookahead"]))
        diff = (future_price - mid).reindex(mid.index)
        diff_abs = diff.abs()
        neutral_range = self._calculate_static_neutral_range(diff_abs)

        df = mid.to_frame("mid")
        df.loc[:, "label_diff"] = diff
        df.loc[:, "neutral_range"] = neutral_range
        df.loc[df.label_diff > 0, "label"] = TernaryDirection.UP.value
        df.loc[df.label_diff < 0, "label"] = TernaryDirection.DOWN.value
        # Applied last so that small moves override the UP/DOWN assignment.
        df.loc[diff_abs < neutral_range, "label"] = TernaryDirection.NEUTRAL.value

        df = Label(df[["label_diff", "label", "neutral_range"]])
        df.label_type = LabelType.TERNARY
        return df

    @property
    def neutral_ratio(self) -> float:
        """Quantile level used for the neutral range, as a float."""
        return float(self._params["neutral_ratio"])

    @property
    def session_splits(self) -> list:
        """Session split times, returned as a new list."""
        return list(self._params["session_splits"])

    @property
    def neutral_hard_limit(self) -> float:
        """Lower bound applied to the neutral range, as a float."""
        return float(self._params["neutral_hard_limit"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from backlight.labelizer.ternary.dynamic_neutral import ( | ||
MarketCloseAwareDynamicNeutralLabelizer as module, | ||
) | ||
|
||
import pytest | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
@pytest.fixture
def sample_df():
    """Return 25 hourly bid/ask/mid rows starting 2017-09-04 13:00 UTC."""
    quotes = [
        (109.68, 109.69, 109.685),
        (109.585, 109.595, 109.59),
        (109.525, 109.535, 109.53),
        (109.6, 109.61, 109.605),
        (109.695, 109.7, 109.6975),
        (109.565, 109.705, 109.635),
        (109.63, 109.685, 109.6575),
        (109.555, 109.675, 109.615),
        (109.7, 109.75, 109.725),
        (109.67, 109.72, 109.695),
        (109.66, 109.675, 109.6675),
        (109.8, 109.815, 109.8075),
        (109.565, 109.575, 109.57),
        (109.535, 109.545, 109.54),
        (109.32, 109.33, 109.325),
        (109.27, 109.275, 109.2725),
        (109.345, 109.355, 109.35),
        (109.305, 109.315, 109.31),
        (109.3, 109.31, 109.305),
        (109.445, 109.46, 109.4525),
        (109.42, 109.425, 109.4225),
        (109.385, 109.395, 109.39),
        (109.305, 109.315, 109.31),
        (109.365, 109.375, 109.37),
        (109.365, 109.375, 109.37),
    ]
    timestamps = pd.date_range(
        "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H"
    )
    return pd.DataFrame(
        data=np.array(quotes), index=timestamps, columns=["bid", "ask", "mid"]
    )
|
||
|
||
def test_create(sample_df):
    """Check the label sum and NaN count of the dynamic-neutral labels."""
    labelizer = module(
        lookahead="1H",
        neutral_ratio=0.5,
        neutral_window="3H",
        neutral_hard_limit=0.00,
    )
    labels = labelizer.create(sample_df)
    assert labels.label.sum() == -3
    assert labels.neutral_range.isna().sum() == 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from backlight.labelizer.ternary.hybrid_neutral import HybridNeutralLabelizer as module | ||
|
||
import pytest | ||
import pandas as pd | ||
import numpy as np | ||
import datetime | ||
|
||
|
||
@pytest.fixture
def sample_df():
    """Return 25 hourly bid/ask/mid rows starting 2017-09-04 13:00 UTC."""
    quotes = [
        (109.68, 109.69, 109.685),
        (109.585, 109.595, 109.59),
        (109.525, 109.535, 109.53),
        (109.6, 109.61, 109.605),
        (109.695, 109.7, 109.6975),
        (109.565, 109.705, 109.635),
        (109.63, 109.685, 109.6575),
        (109.555, 109.675, 109.615),
        (109.7, 109.75, 109.725),
        (109.67, 109.72, 109.695),
        (109.66, 109.675, 109.6675),
        (109.8, 109.815, 109.8075),
        (109.565, 109.575, 109.57),
        (109.535, 109.545, 109.54),
        (109.32, 109.33, 109.325),
        (109.27, 109.275, 109.2725),
        (109.345, 109.355, 109.35),
        (109.305, 109.315, 109.31),
        (109.3, 109.31, 109.305),
        (109.445, 109.46, 109.4525),
        (109.42, 109.425, 109.4225),
        (109.385, 109.395, 109.39),
        (109.305, 109.315, 109.31),
        (109.365, 109.375, 109.37),
        (109.365, 109.375, 109.37),
    ]
    timestamps = pd.date_range(
        "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H"
    )
    return pd.DataFrame(
        data=np.array(quotes), index=timestamps, columns=["bid", "ask", "mid"]
    )
|
||
|
||
def test_create(sample_df):
    """Check the label sum and NaN count of the hybrid-neutral labels."""
    labelizer = module(
        lookahead="1H",
        neutral_ratio=0.5,
        session_splits=[datetime.time(9), datetime.time(18)],
        neutral_window="3H",
        neutral_hard_limit=0.00,
        window_start="20170904 12:00:00+0000",
        window_end="20170905 06:00:00+0000",
        alpha=0.5,
    )
    labels = labelizer.create(sample_df)
    assert labels.label.sum() == 1
    assert labels.neutral_range.isna().sum() == 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from backlight.labelizer.ternary.static_neutral import StaticNeutralLabelizer as module | ||
|
||
import pytest | ||
import pandas as pd | ||
import numpy as np | ||
import datetime | ||
|
||
|
||
@pytest.fixture
def sample_df():
    """Return 25 hourly bid/ask/mid rows starting 2017-09-04 13:00 UTC."""
    quotes = [
        (109.68, 109.69, 109.685),
        (109.585, 109.595, 109.59),
        (109.525, 109.535, 109.53),
        (109.6, 109.61, 109.605),
        (109.695, 109.7, 109.6975),
        (109.565, 109.705, 109.635),
        (109.63, 109.685, 109.6575),
        (109.555, 109.675, 109.615),
        (109.7, 109.75, 109.725),
        (109.67, 109.72, 109.695),
        (109.66, 109.675, 109.6675),
        (109.8, 109.815, 109.8075),
        (109.565, 109.575, 109.57),
        (109.535, 109.545, 109.54),
        (109.32, 109.33, 109.325),
        (109.27, 109.275, 109.2725),
        (109.345, 109.355, 109.35),
        (109.305, 109.315, 109.31),
        (109.3, 109.31, 109.305),
        (109.445, 109.46, 109.4525),
        (109.42, 109.425, 109.4225),
        (109.385, 109.395, 109.39),
        (109.305, 109.315, 109.31),
        (109.365, 109.375, 109.37),
        (109.365, 109.375, 109.37),
    ]
    timestamps = pd.date_range(
        "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H"
    )
    return pd.DataFrame(
        data=np.array(quotes), index=timestamps, columns=["bid", "ask", "mid"]
    )
|
||
|
||
def test_create(sample_df):
    """Check the label sum and NaN count of the static-neutral labels."""
    labelizer = module(
        lookahead="1H",
        neutral_ratio=0.5,
        session_splits=[datetime.time(9), datetime.time(18)],
        neutral_hard_limit=0.00,
        window_start="20170904 12:00:00+0000",
        window_end="20170905 06:00:00+0000",
    )
    labels = labelizer.create(sample_df)
    assert labels.label.sum() == 1
    assert labels.neutral_range.isna().sum() == 0