-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from namoshizun/test-cases-part-3
* Test cases for: Blacklist, Backtester, Optimization, * `StrategyConf.strategy_class`: the value can be the class itself (no need to always save the strategy class code to a importable path)
- Loading branch information
Showing
25 changed files
with
755 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ omit = | |
tradepy/cli/*, | ||
tradepy/trade_cal.py | ||
source = tradepy | ||
# multiprocessing = True | ||
# parallel = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import io | ||
import pytest | ||
import pandas as pd | ||
from unittest import mock | ||
|
||
from tradepy.trade_book.trade_book import TradeBook | ||
from tradepy.core.conf import BacktestConf, StrategyConf, SlippageConf, SL_TP_Order | ||
from tradepy.backtest.backtester import Backtester | ||
from .conftest import SampleBacktestStrategy | ||
|
||
|
||
@pytest.fixture | ||
def backtest_conf(): | ||
return BacktestConf( | ||
cash_amount=1e6, | ||
broker_commission_rate=0.01, | ||
min_broker_commission_fee=0, | ||
use_minute_k=False, | ||
sl_tf_order="stop loss first", | ||
strategy=StrategyConf( | ||
stop_loss=3, | ||
take_profit=4, | ||
take_profit_slip=SlippageConf(method="max_jump", params=1), | ||
stop_loss_slip=SlippageConf(method="max_pct", params=0.1), | ||
max_position_opens=10, | ||
max_position_size=0.25, | ||
min_trade_amount=8000, | ||
), # type: ignore | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def sample_backtester(backtest_conf: BacktestConf): | ||
return Backtester(backtest_conf) | ||
|
||
|
||
@pytest.fixture | ||
def sample_strategy(backtest_conf: BacktestConf): | ||
return SampleBacktestStrategy(backtest_conf.strategy) | ||
|
||
|
||
@pytest.fixture | ||
def sample_computed_day_k_df( | ||
local_stocks_day_k_df: pd.DataFrame, | ||
sample_strategy: SampleBacktestStrategy, | ||
request: pytest.FixtureRequest, | ||
): | ||
cache_key = "cached_computed_day_k" | ||
if raw_csv_str := request.config.cache.get(cache_key, None): | ||
return pd.read_csv(io.StringIO(raw_csv_str), dtype={"code": str}) | ||
|
||
result_df: pd.DataFrame = sample_strategy.compute_all_indicators_df( | ||
local_stocks_day_k_df | ||
) | ||
buf = io.StringIO() | ||
result_df.reset_index(drop=True).to_csv(buf) | ||
buf.seek(0) | ||
request.config.cache.set(cache_key, buf.read()) | ||
return result_df | ||
|
||
|
||
@pytest.fixture | ||
def sample_stock_data(): | ||
_ = None | ||
return pd.DataFrame( | ||
[ | ||
# code, close, sma5, boll_lower, boll_upper, vol, vol_ref1, comment | ||
["000001", 10, 10, 11, _, _, _, "buy: close <= boll_lower"], | ||
["000002", 11, 10, _, _, 100, 90, "buy: sma5 breakthrough"], | ||
["000003", 12, 15, 10, _, _, _, "noop"], | ||
["000004", 13, _, _, 10, _, _, "sell: close >= boll_upper"], | ||
], | ||
columns=[ | ||
"code", | ||
"close", | ||
"sma5", | ||
"boll_lower", | ||
"boll_upper", | ||
"vol", | ||
"vol_ref1", | ||
"comment", | ||
], | ||
).set_index("code") | ||
|
||
|
||
def test_get_buy_options( | ||
sample_backtester: Backtester, | ||
sample_strategy: SampleBacktestStrategy, | ||
sample_stock_data: pd.DataFrame, | ||
): | ||
selector = sample_stock_data["comment"].str.contains("buy") | ||
expect_buy_stocks = sample_stock_data[selector].index | ||
|
||
# Should buy the stocks that trigger the buy signals | ||
buy_options = sample_backtester.get_buy_options(sample_stock_data, sample_strategy) | ||
assert set(buy_options.index) == set(expect_buy_stocks) | ||
|
||
# Check the returned order prices | ||
for code, option in buy_options.iterrows(): | ||
assert option["order_price"] == sample_stock_data.loc[code, "close"] | ||
|
||
|
||
def test_get_close_signals( | ||
sample_backtester: Backtester, | ||
sample_strategy: SampleBacktestStrategy, | ||
sample_stock_data: pd.DataFrame, | ||
): | ||
selector = sample_stock_data["comment"].str.contains("sell") | ||
expect_sell_stocks = sample_stock_data[selector].index | ||
|
||
# Should buy the stocks that trigger the buy signals | ||
with mock.patch( | ||
"tradepy.core.holdings.Holdings.position_codes", | ||
) as mock_position_codes: | ||
mock_position_codes.__get__ = mock.Mock( | ||
return_value=sample_stock_data.index.tolist() | ||
) | ||
sell_codes = sample_backtester.get_close_signals( | ||
sample_stock_data, sample_strategy | ||
) | ||
assert set(sell_codes) == set(expect_sell_stocks) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"sl_tf_order", | ||
["stop loss first", "take profit first", "random"], | ||
) | ||
def test_day_k_trading( | ||
sl_tf_order: SL_TP_Order, | ||
sample_computed_day_k_df: pd.DataFrame, | ||
sample_strategy: SampleBacktestStrategy, | ||
backtest_conf: BacktestConf, | ||
): | ||
backtest_conf.sl_tf_order = sl_tf_order | ||
backtester = Backtester(backtest_conf) | ||
trade_book = backtester.trade(sample_computed_day_k_df, sample_strategy) | ||
|
||
assert isinstance(trade_book, TradeBook) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from unittest import mock | ||
import pytest | ||
import tempfile | ||
import pandas as pd | ||
from pathlib import Path | ||
|
||
import tradepy | ||
from tradepy.blacklist import Blacklist | ||
|
||
|
||
@pytest.yield_fixture(autouse=True) | ||
def temp_blacklist_file(): | ||
with tempfile.NamedTemporaryFile("w") as f: | ||
stocks = pd.DataFrame( | ||
[ | ||
["000001", "2021-01-01"], | ||
["000002", "2030-01-02"], | ||
["000003", None], | ||
], | ||
columns=["code", "until"], | ||
).set_index("code") | ||
stocks.to_csv(f.name) | ||
|
||
with mock.patch("tradepy.config.common.blacklist_path", Path(f.name)): | ||
yield f.name | ||
Blacklist.purge_cache() | ||
|
||
|
||
def test_contains_when_blacklist_is_empty(): | ||
with mock.patch("tradepy.config.common.blacklist_path", None): | ||
assert not Blacklist.contains("000002") | ||
|
||
|
||
def test_contains_when_code_is_not_blacklisted(): | ||
assert not Blacklist.contains("11111") | ||
|
||
|
||
def test_contains_when_code_is_blacklisted(): | ||
assert Blacklist.contains("000002") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"timestamp", | ||
[ | ||
"1900-12-31", | ||
"2020-12-31", | ||
"2029-12-31", | ||
"2030-01-01", | ||
], | ||
) | ||
def test_always_contains_when_no_due_date_is_set(timestamp): | ||
assert Blacklist.contains("000003", timestamp) | ||
|
||
|
||
def test_contains_when_before_blacklist_before_due_date(): | ||
assert Blacklist.contains("000001", "2020-12-31") | ||
|
||
|
||
def test_not_contains_when_over_the_blacklist_due_date(): | ||
assert not Blacklist.contains("000001", "2021-12-31") | ||
|
||
|
||
def test_contains_with_invalid_timestamp_format(): | ||
with pytest.raises(ValueError): | ||
Blacklist.contains("000001", "invalid-date-format") | ||
|
||
|
||
def test_read_from_empty_file(): | ||
with tempfile.NamedTemporaryFile("w") as f: | ||
with mock.patch("tradepy.config.common.blacklist_path", Path(f.name)): | ||
assert not Blacklist.read() | ||
|
||
|
||
def test_read_from_nonexistent_file(): | ||
with pytest.raises(FileNotFoundError): | ||
tradepy.config.common.blacklist_path = Path("/nonexistent/file/path.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
import os | ||
import pytest | ||
import contextlib | ||
import tempfile | ||
import pandas as pd | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
import sys | ||
import tempfile | ||
import pytest | ||
import inspect | ||
|
||
from tradepy.core.conf import StrategyConf | ||
from tradepy.strategy.base import StrategyBase | ||
from .conftest import SampleBacktestStrategy | ||
|
||
|
||
@pytest.yield_fixture | ||
def strategy_class_import_path(): | ||
# Save the source code of SampleBacktestStrategy to a temporary file, and add the file | ||
# directory to the current system path | ||
source = inspect.getsource(SampleBacktestStrategy) | ||
|
||
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+") as f: | ||
f.write("import talib\n") | ||
f.write("import pandas as pd\n") | ||
f.write("from tradepy.strategy.base import BacktestStrategy, BuyOption\n") | ||
f.write("from tradepy.decorators import tag\n") | ||
f.write("from tradepy.strategy.factors import FactorsMixin\n\n") | ||
f.write(source) | ||
f.flush() | ||
temp_module_dir = os.path.dirname(f.name) | ||
sys.path.append(temp_module_dir) | ||
|
||
try: | ||
yield f"{os.path.basename(f.name)[:-3]}.SampleBacktestStrategy" | ||
finally: | ||
sys.path.remove(temp_module_dir) | ||
|
||
|
||
@pytest.fixture | ||
def strategy_class(): | ||
return SampleBacktestStrategy | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"strategy_class_value", | ||
["strategy_class_import_path", "strategy_class"], | ||
) | ||
def test_load_strategy_class(strategy_class_value, request): | ||
conf = StrategyConf(strategy_class=request.getfixturevalue(strategy_class_value)) # type: ignore | ||
strategy_class = conf.load_strategy_class() | ||
assert issubclass(strategy_class, StrategyBase) | ||
|
||
strategy_instance = conf.load_strategy() | ||
assert isinstance(strategy_instance, StrategyBase) |
Oops, something went wrong.