Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qmt KData auto download and history data download #229

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions src/zvt/broker/qmt/qmt_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,31 @@ def get_entity_list():


def get_kdata(
entity_id,
start_timestamp,
end_timestamp,
level=IntervalLevel.LEVEL_1DAY,
adjust_type=AdjustType.qfq,
download_history=True,
entity_id,
start_timestamp,
end_timestamp,
level=IntervalLevel.LEVEL_1DAY,
adjust_type=AdjustType.qfq,
download_history=True,
):
code = _to_qmt_code(entity_id=entity_id)
period = level.value
start_time = to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss")
end_time = to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss")

# download比较耗时,建议单独定时任务来做
if download_history:
xtdata.download_history_data(stock_code=code, period=period)
# print(
# f"download from {start_time} to {end_time}")
xtdata.download_history_data(
stock_code=code, period=period,
start_time=start_time, end_time=end_time
)
records = xtdata.get_market_data(
stock_list=[code],
period=period,
start_time=to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss"),
end_time=to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss"),
start_time=start_time,
end_time=end_time,
dividend_type=_to_qmt_dividend_type(adjust_type=adjust_type),
fill_data=False,
)
Expand Down
94 changes: 69 additions & 25 deletions src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,49 @@
from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import df_to_db
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.contract.utils import evaluate_size_from_timestamp
from zvt.domain import (
Stock,
StockKdataCommon,
)
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import current_date, to_time_str
from zvt.utils.time_utils import current_date, to_time_str, TIME_FORMAT_DAY, TIME_FORMAT_MINUTE


class BaseQmtKdataRecorder(FixedCycleDataRecorder):
default_size = 50000
entity_provider: str = "qmt"

provider = "qmt"
download_history_data = False

def __init__(
self,
force_update=True,
sleeping_time=10,
exchanges=None,
entity_id=None,
entity_ids=None,
code=None,
codes=None,
day_data=False,
entity_filters=None,
ignore_failed=True,
real_time=False,
fix_duplicate_way="ignore",
start_timestamp=None,
end_timestamp=None,
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
adjust_type=AdjustType.qfq,
return_unfinished=False,
self,
force_update=True,
sleeping_time=10,
exchanges=None,
entity_id=None,
entity_ids=None,
code=None,
codes=None,
day_data=False,
entity_filters=None,
ignore_failed=True,
real_time=False,
fix_duplicate_way="ignore",
start_timestamp=None,
end_timestamp=None,
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
adjust_type=AdjustType.qfq,
return_unfinished=False,
download_history_data=False
) -> None:
level = IntervalLevel(level)
self.adjust_type = AdjustType(adjust_type)
self.entity_type = self.entity_schema.__name__.lower()
self.download_history_data = download_history_data

self.data_schema = get_kdata_schema(entity_type=self.entity_type, level=level, adjust_type=self.adjust_type)

Expand All @@ -68,6 +72,7 @@ def __init__(
one_day_trading_minutes,
return_unfinished,
)
self.one_day_trading_minutes = 240

def record(self, entity, start, end, size, timestamps):
if start and (self.level == IntervalLevel.LEVEL_1DAY):
Expand All @@ -81,7 +86,7 @@ def record(self, entity, start, end, size, timestamps):
end_timestamp=start,
adjust_type=self.adjust_type,
level=self.level,
download_history=False,
download_history=self.download_history_data,
)
if pd_is_not_null(check_df):
current_df = get_kdata(
Expand All @@ -107,18 +112,24 @@ def record(self, entity, start, end, size, timestamps):
if not end:
end = current_date()

# 统一高频数据习惯,减小数据更新次数,分钟K线需要直接多读1根K线,以兼容start_timestamp=9:30, end_timestamp=15:00的情况
if self.level == IntervalLevel.LEVEL_1MIN:
end += pd.Timedelta(seconds=1)

df = qmt_quote.get_kdata(
entity_id=entity.id,
start_timestamp=start,
end_timestamp=end,
adjust_type=self.adjust_type,
level=self.level,
download_history=False,
download_history=self.download_history_data,
)
time_str_fmt = TIME_FORMAT_DAY if self.level == IntervalLevel.LEVEL_1DAY else TIME_FORMAT_MINUTE
if pd_is_not_null(df):
df["entity_id"] = entity.id
df["timestamp"] = pd.to_datetime(df.index)
df["id"] = df.apply(lambda row: f"{row['entity_id']}_{to_time_str(row['timestamp'])}", axis=1)
df["id"] = df.apply(lambda row: f"{row['entity_id']}_{to_time_str(row['timestamp'], fmt=time_str_fmt)}",
axis=1)
df["provider"] = "qmt"
df["level"] = self.level.value
df["code"] = entity.code
Expand All @@ -130,6 +141,39 @@ def record(self, entity, start, end, size, timestamps):
else:
self.logger.info(f"no kdata for {entity.id}")

def evaluate_start_end_size_timestamps(self, entity):
    """Decide the fetch window and expected bar count for *entity*.

    Returns the ``(start_timestamp, end_timestamp, size, timestamps)`` tuple
    consumed by the recorder's ``record()`` loop.

    When ``download_history_data`` is enabled and an explicit
    ``start_timestamp``/``end_timestamp`` window was configured, the stored
    row count for that window is compared against the theoretically expected
    bar count; any mismatch triggers a re-fetch of the whole window.
    """
    if self.download_history_data and self.start_timestamp and self.end_timestamp:
        # History data may be fragmented: check whether the rows actually
        # stored between start and end fully cover the expected bar count.
        expected_size = evaluate_size_from_timestamp(start_timestamp=self.start_timestamp,
                                                     end_timestamp=self.end_timestamp, level=self.level,
                                                     one_day_trading_minutes=self.one_day_trading_minutes)

        recorded_size = self.session.query(self.data_schema).filter(
            self.data_schema.entity_id == entity.id,
            self.data_schema.timestamp >= self.start_timestamp,
            self.data_schema.timestamp <= self.end_timestamp
        ).count()

        if expected_size != recorded_size:
            # Gap detected: request the full configured window again.
            return self.start_timestamp, self.end_timestamp, self.default_size, None

    start_timestamp, end_timestamp, size, timestamps = super().evaluate_start_end_size_timestamps(entity)
    # start_timestamp is the last updated timestamp
    if self.end_timestamp is not None:
        if start_timestamp >= self.end_timestamp:
            # Already recorded up to (or past) the requested end — nothing to fetch.
            return start_timestamp, end_timestamp, 0, None
        else:
            # Re-derive the expected bar count for the remaining
            # [start_timestamp, self.end_timestamp] span.
            size = evaluate_size_from_timestamp(
                start_timestamp=start_timestamp,
                level=self.level,
                one_day_trading_minutes=self.one_day_trading_minutes,
                end_timestamp=self.end_timestamp,
            )
            return start_timestamp, self.end_timestamp, size, timestamps

    return start_timestamp, end_timestamp, size, timestamps


class QMTStockKdataRecorder(BaseQmtKdataRecorder):
entity_schema = Stock
Expand All @@ -138,7 +182,7 @@ class QMTStockKdataRecorder(BaseQmtKdataRecorder):

if __name__ == "__main__":
# Stock.record_data(provider="qmt")
QMTStockKdataRecorder(entity_id="stock_sz_301611", adjust_type=AdjustType.qfq).run()
QMTStockKdataRecorder(entity_id="stock_sz_002231", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN).run()

# the __all__ is generated
__all__ = ["BaseQmtKdataRecorder", "QMTStockKdataRecorder"]
65 changes: 65 additions & 0 deletions tests/recorders/qmt/test_qmt_1m_stock_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest

import pandas as pd

from zvt.contract import AdjustType, IntervalLevel
from zvt.domain import Stock1mKdata
from zvt.recorders.qmt.quotes import QMTStockKdataRecorder


class MyTestCase(unittest.TestCase):
    """End-to-end checks of QMTStockKdataRecorder's 1-minute download behaviour.

    NOTE(review): this test talks to a live qmt/xtdata data source and mutates
    the local zvt database — it is an integration test, not a unit test.
    """

    def test_qmt_stock_recorder(self):
        # NOTE: during back-testing `timestamp` is the current time, and the
        # extra 15:00 bar is needed to analyse a complete trading day; in live
        # trading that closing bar may not be obtainable yet.
        trade_start_minute = pd.Timestamp('2024-01-02 9:30')
        trade_end_minute = pd.Timestamp('2024-01-02 15:00')
        old_trade_start_minute = pd.Timestamp('2014-01-02 9:30')
        old_trade_end_minute = pd.Timestamp('2014-01-02 15:00')
        recorder = QMTStockKdataRecorder(
            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
            start_timestamp=trade_start_minute, end_timestamp=trade_end_minute,
            download_history_data=True
        )
        # Manually delete all stored rows for this stock first, restoring a
        # clean initial state for both time windows.
        recorder.session.query(recorder.data_schema).filter(
            recorder.data_schema.entity_id == "stock_sz_000488",
            recorder.data_schema.timestamp >= trade_start_minute,
            recorder.data_schema.timestamp <= trade_end_minute
        ).delete()
        recorder.session.query(recorder.data_schema).filter(
            recorder.data_schema.entity_id == "stock_sz_000488",
            recorder.data_schema.timestamp >= old_trade_start_minute,
            recorder.data_schema.timestamp <= old_trade_end_minute
        ).delete()
        recorder.run()
        records = Stock1mKdata.query_data(provider="qmt", entity_ids=["stock_sz_000488"],
                                          start_timestamp=trade_start_minute, end_timestamp=trade_end_minute, )
        # 09:30-15:00 inclusive yields 241 one-minute bars.
        self.assertEqual(records.shape[0], 241)
        # Without download_history_data the recorder skips the (missing) old
        # history entirely, so no rows are expected for the 2014 window.
        QMTStockKdataRecorder(
            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
        ).run()
        records = Stock1mKdata.query_data(
            provider="qmt", entity_ids=["stock_sz_000488"],
            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,

        )
        self.assertEqual(records.shape[0], 0)

        # History data must be downloaded too: it is unlikely everything can be
        # crawled in one pass, so downloading on demand while back-tests run is
        # the more reasonable workflow.
        QMTStockKdataRecorder(
            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
            download_history_data=True
        ).run()
        records = Stock1mKdata.query_data(
            provider="qmt", entity_ids=["stock_sz_000488"],
            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,

        )
        self.assertEqual(records.shape[0], 241)


if __name__ == '__main__':
unittest.main()