Skip to content

Commit

Permalink
Merge pull request #49 from namoshizun/test-cases-part-3
Browse files Browse the repository at this point in the history
* Test cases for: Blacklist, Backtester, Optimization,
* `StrategyConf.strategy_class`: the value can be the class itself (no need to always save the strategy class code to a importable path)
  • Loading branch information
namoshizun authored Oct 28, 2023
2 parents af9d887 + 3829979 commit 00d2a61
Show file tree
Hide file tree
Showing 25 changed files with 755 additions and 241 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ omit =
tradepy/cli/*,
tradepy/trade_cal.py
source = tradepy
# multiprocessing = True
# parallel = true
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ results/
backup/
playground/
.ipynb_checkpoints/

test.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
287 changes: 149 additions & 138 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ celery = { extras = ["redis"], version = "5.3.0" }
uvicorn = { extras = ["standard"], version = "^0.21.1" }
ta-lib = "0.4.26"
quantstats = "^0.0.59"
dask = { extras = ["diagnostics", "distributed"], version = "^2023.5.0" }
dask = { extras = ["diagnostics", "distributed"], version = "2023.5.0" }
tabula-py = "^2.7.0"
pyyaml = "^6.0.1"
scikit-learn = "^1.3.0"
Expand Down
38 changes: 37 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import pytest
import tempfile
import talib
import pandas as pd
from pathlib import Path
from loguru import logger

Expand All @@ -14,6 +17,9 @@
)
from tradepy.depot.stocks import StockListingDepot, StocksDailyBarsDepot
from tradepy.depot.misc import AdjustFactorDepot
from tradepy.strategy.base import BacktestStrategy, BuyOption
from tradepy.strategy.factors import FactorsMixin
from tradepy.decorators import tag
from .fixtures_data.load import load_dataset


Expand All @@ -24,11 +30,12 @@ def init_tradepy_config():
return

working_dir = tempfile.TemporaryDirectory()
working_dir_path = Path(working_dir.name)
logger.info(f"Initializing tradepy config... Temporary database dir: {working_dir}")
tradepy.config = TradePyConf(
common=CommonConf(
mode="backtest",
database_dir=Path(working_dir.name),
database_dir=working_dir_path,
trade_lot_vol=100,
blacklist_path=None,
redis=None,
Expand All @@ -46,7 +53,13 @@ def init_tradepy_config():
schedules=SchedulesConf(), # type: ignore
notifications=None,
)

tradepy.config.save_to_config_file(
temp_config_path := working_dir_path / "config.yaml"
)
os.environ["TRADEPY_CONFIG_FILE"] = str(temp_config_path)
yield tradepy.config
os.environ.pop("TRADEPY_CONFIG_FILE")
working_dir.cleanup()


Expand Down Expand Up @@ -74,3 +87,26 @@ def local_stocks_day_k_df():
@pytest.fixture
def local_adjust_factors_df():
return AdjustFactorDepot.load()


class SampleBacktestStrategy(BacktestStrategy, FactorsMixin):
@tag(notna=True)
def vol_ref1(self, vol):
return vol.shift(1)

@tag(notna=True, outputs=["boll_upper", "boll_middle", "boll_lower"])
def boll_21(self, close):
return talib.BBANDS(close, 21, 2, 2)

def should_buy(self, sma5, boll_lower, close, vol, vol_ref1) -> BuyOption | None:
if close <= boll_lower:
return close, 1

if close >= sma5 and vol > vol_ref1:
return close, 1

def should_sell(self, close, boll_upper) -> bool:
return close >= boll_upper

def pre_process(self, bars_df: pd.DataFrame) -> pd.DataFrame:
return bars_df.query('market != "科创板"').copy()
2 changes: 2 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
skip_if_in_ci = lambda: pytest.mark.skipif(
os.environ.get("CI", "false").lower() == "true",
reason="Skip this test in CI environment",
# Because the network access from the Github CI environment to
# the external APIs (which are hosted in China) is unstable
)
138 changes: 138 additions & 0 deletions tests/test_backtester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import io
import pytest
import pandas as pd
from unittest import mock

from tradepy.trade_book.trade_book import TradeBook
from tradepy.core.conf import BacktestConf, StrategyConf, SlippageConf, SL_TP_Order
from tradepy.backtest.backtester import Backtester
from .conftest import SampleBacktestStrategy


@pytest.fixture
def backtest_conf():
return BacktestConf(
cash_amount=1e6,
broker_commission_rate=0.01,
min_broker_commission_fee=0,
use_minute_k=False,
sl_tf_order="stop loss first",
strategy=StrategyConf(
stop_loss=3,
take_profit=4,
take_profit_slip=SlippageConf(method="max_jump", params=1),
stop_loss_slip=SlippageConf(method="max_pct", params=0.1),
max_position_opens=10,
max_position_size=0.25,
min_trade_amount=8000,
), # type: ignore
)


@pytest.fixture
def sample_backtester(backtest_conf: BacktestConf):
return Backtester(backtest_conf)


@pytest.fixture
def sample_strategy(backtest_conf: BacktestConf):
return SampleBacktestStrategy(backtest_conf.strategy)


@pytest.fixture
def sample_computed_day_k_df(
local_stocks_day_k_df: pd.DataFrame,
sample_strategy: SampleBacktestStrategy,
request: pytest.FixtureRequest,
):
cache_key = "cached_computed_day_k"
if raw_csv_str := request.config.cache.get(cache_key, None):
return pd.read_csv(io.StringIO(raw_csv_str), dtype={"code": str})

result_df: pd.DataFrame = sample_strategy.compute_all_indicators_df(
local_stocks_day_k_df
)
buf = io.StringIO()
result_df.reset_index(drop=True).to_csv(buf)
buf.seek(0)
request.config.cache.set(cache_key, buf.read())
return result_df


@pytest.fixture
def sample_stock_data():
_ = None
return pd.DataFrame(
[
# code, close, sma5, boll_lower, boll_upper, vol, vol_ref1, comment
["000001", 10, 10, 11, _, _, _, "buy: close <= boll_lower"],
["000002", 11, 10, _, _, 100, 90, "buy: sma5 breakthrough"],
["000003", 12, 15, 10, _, _, _, "noop"],
["000004", 13, _, _, 10, _, _, "sell: close >= boll_upper"],
],
columns=[
"code",
"close",
"sma5",
"boll_lower",
"boll_upper",
"vol",
"vol_ref1",
"comment",
],
).set_index("code")


def test_get_buy_options(
sample_backtester: Backtester,
sample_strategy: SampleBacktestStrategy,
sample_stock_data: pd.DataFrame,
):
selector = sample_stock_data["comment"].str.contains("buy")
expect_buy_stocks = sample_stock_data[selector].index

# Should buy the stocks that trigger the buy signals
buy_options = sample_backtester.get_buy_options(sample_stock_data, sample_strategy)
assert set(buy_options.index) == set(expect_buy_stocks)

# Check the returned order prices
for code, option in buy_options.iterrows():
assert option["order_price"] == sample_stock_data.loc[code, "close"]


def test_get_close_signals(
sample_backtester: Backtester,
sample_strategy: SampleBacktestStrategy,
sample_stock_data: pd.DataFrame,
):
selector = sample_stock_data["comment"].str.contains("sell")
expect_sell_stocks = sample_stock_data[selector].index

# Should buy the stocks that trigger the buy signals
with mock.patch(
"tradepy.core.holdings.Holdings.position_codes",
) as mock_position_codes:
mock_position_codes.__get__ = mock.Mock(
return_value=sample_stock_data.index.tolist()
)
sell_codes = sample_backtester.get_close_signals(
sample_stock_data, sample_strategy
)
assert set(sell_codes) == set(expect_sell_stocks)


@pytest.mark.parametrize(
"sl_tf_order",
["stop loss first", "take profit first", "random"],
)
def test_day_k_trading(
sl_tf_order: SL_TP_Order,
sample_computed_day_k_df: pd.DataFrame,
sample_strategy: SampleBacktestStrategy,
backtest_conf: BacktestConf,
):
backtest_conf.sl_tf_order = sl_tf_order
backtester = Backtester(backtest_conf)
trade_book = backtester.trade(sample_computed_day_k_df, sample_strategy)

assert isinstance(trade_book, TradeBook)
76 changes: 76 additions & 0 deletions tests/test_blacklist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from unittest import mock
import pytest
import tempfile
import pandas as pd
from pathlib import Path

import tradepy
from tradepy.blacklist import Blacklist


@pytest.yield_fixture(autouse=True)
def temp_blacklist_file():
with tempfile.NamedTemporaryFile("w") as f:
stocks = pd.DataFrame(
[
["000001", "2021-01-01"],
["000002", "2030-01-02"],
["000003", None],
],
columns=["code", "until"],
).set_index("code")
stocks.to_csv(f.name)

with mock.patch("tradepy.config.common.blacklist_path", Path(f.name)):
yield f.name
Blacklist.purge_cache()


def test_contains_when_blacklist_is_empty():
with mock.patch("tradepy.config.common.blacklist_path", None):
assert not Blacklist.contains("000002")


def test_contains_when_code_is_not_blacklisted():
assert not Blacklist.contains("11111")


def test_contains_when_code_is_blacklisted():
assert Blacklist.contains("000002")


@pytest.mark.parametrize(
"timestamp",
[
"1900-12-31",
"2020-12-31",
"2029-12-31",
"2030-01-01",
],
)
def test_always_contains_when_no_due_date_is_set(timestamp):
assert Blacklist.contains("000003", timestamp)


def test_contains_when_before_blacklist_before_due_date():
assert Blacklist.contains("000001", "2020-12-31")


def test_not_contains_when_over_the_blacklist_due_date():
assert not Blacklist.contains("000001", "2021-12-31")


def test_contains_with_invalid_timestamp_format():
with pytest.raises(ValueError):
Blacklist.contains("000001", "invalid-date-format")


def test_read_from_empty_file():
with tempfile.NamedTemporaryFile("w") as f:
with mock.patch("tradepy.config.common.blacklist_path", Path(f.name)):
assert not Blacklist.read()


def test_read_from_nonexistent_file():
with pytest.raises(FileNotFoundError):
tradepy.config.common.blacklist_path = Path("/nonexistent/file/path.csv")
2 changes: 0 additions & 2 deletions tests/test_collectors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import pytest
import contextlib
import tempfile
import pandas as pd
Expand Down
50 changes: 50 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import sys
import tempfile
import pytest
import inspect

from tradepy.core.conf import StrategyConf
from tradepy.strategy.base import StrategyBase
from .conftest import SampleBacktestStrategy


@pytest.yield_fixture
def strategy_class_import_path():
# Save the source code of SampleBacktestStrategy to a temporary file, and add the file
# directory to the current system path
source = inspect.getsource(SampleBacktestStrategy)

with tempfile.NamedTemporaryFile(suffix=".py", mode="w+") as f:
f.write("import talib\n")
f.write("import pandas as pd\n")
f.write("from tradepy.strategy.base import BacktestStrategy, BuyOption\n")
f.write("from tradepy.decorators import tag\n")
f.write("from tradepy.strategy.factors import FactorsMixin\n\n")
f.write(source)
f.flush()
temp_module_dir = os.path.dirname(f.name)
sys.path.append(temp_module_dir)

try:
yield f"{os.path.basename(f.name)[:-3]}.SampleBacktestStrategy"
finally:
sys.path.remove(temp_module_dir)


@pytest.fixture
def strategy_class():
return SampleBacktestStrategy


@pytest.mark.parametrize(
"strategy_class_value",
["strategy_class_import_path", "strategy_class"],
)
def test_load_strategy_class(strategy_class_value, request):
conf = StrategyConf(strategy_class=request.getfixturevalue(strategy_class_value)) # type: ignore
strategy_class = conf.load_strategy_class()
assert issubclass(strategy_class, StrategyBase)

strategy_instance = conf.load_strategy()
assert isinstance(strategy_instance, StrategyBase)
Loading

0 comments on commit 00d2a61

Please sign in to comment.