diff --git a/src/zvt/broker/qmt/qmt_quote.py b/src/zvt/broker/qmt/qmt_quote.py
index 562aca0d..0af64f5d 100644
--- a/src/zvt/broker/qmt/qmt_quote.py
+++ b/src/zvt/broker/qmt/qmt_quote.py
@@ -137,23 +137,31 @@ def get_entity_list():
 
 
 def get_kdata(
-    entity_id,
-    start_timestamp,
-    end_timestamp,
-    level=IntervalLevel.LEVEL_1DAY,
-    adjust_type=AdjustType.qfq,
-    download_history=True,
+        entity_id,
+        start_timestamp,
+        end_timestamp,
+        level=IntervalLevel.LEVEL_1DAY,
+        adjust_type=AdjustType.qfq,
+        download_history=True,
 ):
     code = _to_qmt_code(entity_id=entity_id)
     period = level.value
+    start_time = to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss")
+    end_time = to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss")
+
     # downloading is slow; better to run it as a separate scheduled task
     if download_history:
+        # print(
+        #     f"download from {start_time} to {end_time}")
-        xtdata.download_history_data(stock_code=code, period=period)
+        xtdata.download_history_data(
+            stock_code=code, period=period,
+            start_time=start_time, end_time=end_time
+        )
     records = xtdata.get_market_data(
         stock_list=[code],
         period=period,
-        start_time=to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss"),
-        end_time=to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss"),
+        start_time=start_time,
+        end_time=end_time,
         dividend_type=_to_qmt_dividend_type(adjust_type=adjust_type),
         fill_data=False,
     )
diff --git a/src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py b/src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
index cfc665fa..955a9059 100644
--- a/src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
+++ b/src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
@@ -6,12 +6,13 @@
 from zvt.contract import IntervalLevel, AdjustType
 from zvt.contract.api import df_to_db
 from zvt.contract.recorder import FixedCycleDataRecorder
+from zvt.contract.utils import evaluate_size_from_timestamp
 from zvt.domain import (
     Stock,
     StockKdataCommon,
 )
 from zvt.utils.pd_utils import pd_is_not_null
-from zvt.utils.time_utils import current_date, to_time_str
+from zvt.utils.time_utils import current_date, to_time_str, TIME_FORMAT_DAY, TIME_FORMAT_MINUTE
 
 
 class BaseQmtKdataRecorder(FixedCycleDataRecorder):
@@ -19,32 +20,35 @@ class BaseQmtKdataRecorder(FixedCycleDataRecorder):
     entity_provider: str = "qmt"
     provider = "qmt"
+    download_history_data = False
 
     def __init__(
-        self,
-        force_update=True,
-        sleeping_time=10,
-        exchanges=None,
-        entity_id=None,
-        entity_ids=None,
-        code=None,
-        codes=None,
-        day_data=False,
-        entity_filters=None,
-        ignore_failed=True,
-        real_time=False,
-        fix_duplicate_way="ignore",
-        start_timestamp=None,
-        end_timestamp=None,
-        level=IntervalLevel.LEVEL_1DAY,
-        kdata_use_begin_time=False,
-        one_day_trading_minutes=24 * 60,
-        adjust_type=AdjustType.qfq,
-        return_unfinished=False,
+            self,
+            force_update=True,
+            sleeping_time=10,
+            exchanges=None,
+            entity_id=None,
+            entity_ids=None,
+            code=None,
+            codes=None,
+            day_data=False,
+            entity_filters=None,
+            ignore_failed=True,
+            real_time=False,
+            fix_duplicate_way="ignore",
+            start_timestamp=None,
+            end_timestamp=None,
+            level=IntervalLevel.LEVEL_1DAY,
+            kdata_use_begin_time=False,
+            one_day_trading_minutes=24 * 60,
+            adjust_type=AdjustType.qfq,
+            return_unfinished=False,
+            download_history_data=False
     ) -> None:
         level = IntervalLevel(level)
         self.adjust_type = AdjustType(adjust_type)
         self.entity_type = self.entity_schema.__name__.lower()
+        self.download_history_data = download_history_data
         self.data_schema = get_kdata_schema(entity_type=self.entity_type, level=level, adjust_type=self.adjust_type)
@@ -68,6 +72,7 @@ def __init__(
             one_day_trading_minutes,
             return_unfinished,
         )
+        self.one_day_trading_minutes = 240
 
     def record(self, entity, start, end, size, timestamps):
         if start and (self.level == IntervalLevel.LEVEL_1DAY):
@@ -81,7 +86,7 @@ def record(self, entity, start, end, size, timestamps):
                 end_timestamp=start,
                 adjust_type=self.adjust_type,
                 level=self.level,
-                download_history=False,
+                download_history=self.download_history_data,
             )
             if pd_is_not_null(check_df):
                 current_df = get_kdata(
@@ -107,18 +112,24 @@ def record(self, entity, start, end, size, timestamps):
         if not end:
             end = current_date()
 
+        # Keep the high-frequency data convention and reduce update rounds: read one extra 1-minute bar so that start_timestamp=9:30, end_timestamp=15:00 is fully covered
+        if self.level == IntervalLevel.LEVEL_1MIN:
+            end += pd.Timedelta(seconds=1)
+
         df = qmt_quote.get_kdata(
             entity_id=entity.id,
             start_timestamp=start,
             end_timestamp=end,
             adjust_type=self.adjust_type,
             level=self.level,
-            download_history=False,
+            download_history=self.download_history_data,
         )
+        time_str_fmt = TIME_FORMAT_DAY if self.level == IntervalLevel.LEVEL_1DAY else TIME_FORMAT_MINUTE
         if pd_is_not_null(df):
             df["entity_id"] = entity.id
             df["timestamp"] = pd.to_datetime(df.index)
-            df["id"] = df.apply(lambda row: f"{row['entity_id']}_{to_time_str(row['timestamp'])}", axis=1)
+            df["id"] = df.apply(lambda row: f"{row['entity_id']}_{to_time_str(row['timestamp'], fmt=time_str_fmt)}",
+                                axis=1)
             df["provider"] = "qmt"
             df["level"] = self.level.value
             df["code"] = entity.code
@@ -130,6 +141,39 @@ def record(self, entity, start, end, size, timestamps):
         else:
             self.logger.info(f"no kdata for {entity.id}")
 
+    def evaluate_start_end_size_timestamps(self, entity):
+        if self.download_history_data and self.start_timestamp and self.end_timestamp:
+            # historical data may be fragmented, so check whether the range between start and end is actually fully recorded
+            expected_size = evaluate_size_from_timestamp(start_timestamp=self.start_timestamp,
+                                                         end_timestamp=self.end_timestamp, level=self.level,
+                                                         one_day_trading_minutes=self.one_day_trading_minutes)
+
+            recorded_size = self.session.query(self.data_schema).filter(
+                self.data_schema.entity_id == entity.id,
+                self.data_schema.timestamp >= self.start_timestamp,
+                self.data_schema.timestamp <= self.end_timestamp
+            ).count()
+
+            if expected_size != recorded_size:
+                # print(f"expected_size: {expected_size}, recorded_size: {recorded_size}")
+                return self.start_timestamp, self.end_timestamp, self.default_size, None
+
+        start_timestamp, end_timestamp, size, timestamps = super().evaluate_start_end_size_timestamps(entity)
+        # start_timestamp is the last updated timestamp
+        if self.end_timestamp is not None:
+            if start_timestamp >= self.end_timestamp:
+                return start_timestamp, end_timestamp, 0, None
+            else:
+                size = evaluate_size_from_timestamp(
+                    start_timestamp=start_timestamp,
+                    level=self.level,
+                    one_day_trading_minutes=self.one_day_trading_minutes,
+                    end_timestamp=self.end_timestamp,
+                )
+                return start_timestamp, self.end_timestamp, size, timestamps
+
+        return start_timestamp, end_timestamp, size, timestamps
+
 
 class QMTStockKdataRecorder(BaseQmtKdataRecorder):
     entity_schema = Stock
@@ -138,7 +182,7 @@ class QMTStockKdataRecorder(BaseQmtKdataRecorder):
 
 if __name__ == "__main__":
     # Stock.record_data(provider="qmt")
-    QMTStockKdataRecorder(entity_id="stock_sz_301611", adjust_type=AdjustType.qfq).run()
+    QMTStockKdataRecorder(entity_id="stock_sz_002231", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN).run()
 
 # the __all__ is generated
 __all__ = ["BaseQmtKdataRecorder", "QMTStockKdataRecorder"]
diff --git a/tests/recorders/qmt/test_qmt_1m_stock_recorder.py b/tests/recorders/qmt/test_qmt_1m_stock_recorder.py
new file mode 100644
index 00000000..32dee1d8
--- /dev/null
+++ b/tests/recorders/qmt/test_qmt_1m_stock_recorder.py
@@ -0,0 +1,65 @@
+import unittest
+
+import pandas as pd
+
+from zvt.contract import AdjustType, IntervalLevel
+from zvt.domain import Stock1mKdata
+from zvt.recorders.qmt.quotes import QMTStockKdataRecorder
+
+
+class MyTestCase(unittest.TestCase):
+    def test_qmt_stock_recorder(self):
+        # Note: in backtesting the timestamp is the current time, so the extra 15:00 bar is needed to
+        # analyze the complete data for the day; in live trading that bar may not be available yet
+        trade_start_minute = pd.Timestamp('2024-01-02 9:30')
+        trade_end_minute = pd.Timestamp('2024-01-02 15:00')
+        old_trade_start_minute = pd.Timestamp('2014-01-02 9:30')
+        old_trade_end_minute = pd.Timestamp('2014-01-02 15:00')
+        recorder = QMTStockKdataRecorder(
+            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
+            start_timestamp=trade_start_minute, end_timestamp=trade_end_minute,
+            download_history_data=True
+        )
+        # delete this stock's data in the tested ranges first to restore a clean initial state
+        recorder.session.query(recorder.data_schema).filter(
+            recorder.data_schema.entity_id == "stock_sz_000488",
+            recorder.data_schema.timestamp >= trade_start_minute,
+            recorder.data_schema.timestamp <= trade_end_minute
+        ).delete()
+        recorder.session.query(recorder.data_schema).filter(
+            recorder.data_schema.entity_id == "stock_sz_000488",
+            recorder.data_schema.timestamp >= old_trade_start_minute,
+            recorder.data_schema.timestamp <= old_trade_end_minute
+        ).delete()
+        recorder.run()
+        records = Stock1mKdata.query_data(provider="qmt", entity_ids=["stock_sz_000488"],
+                                          start_timestamp=trade_start_minute, end_timestamp=trade_end_minute)
+        self.assertEqual(records.shape[0], 241)
+        # without download_history_data, historical gaps are ignored and the data is assumed to be up to date
+        QMTStockKdataRecorder(
+            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
+            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
+        ).run()
+        records = Stock1mKdata.query_data(
+            provider="qmt", entity_ids=["stock_sz_000488"],
+            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
+        )
+        self.assertEqual(records.shape[0], 0)
+
+        # historical data also needs downloading: it is unlikely all of it can be fetched in one go,
+        # so downloading while running backtests is reasonable
+        QMTStockKdataRecorder(
+            entity_id="stock_sz_000488", adjust_type=AdjustType.qfq, level=IntervalLevel.LEVEL_1MIN,
+            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
+            download_history_data=True
+        ).run()
+        records = Stock1mKdata.query_data(
+            provider="qmt", entity_ids=["stock_sz_000488"],
+            start_timestamp=old_trade_start_minute, end_timestamp=old_trade_end_minute,
+        )
+        self.assertEqual(records.shape[0], 241)
+
+
+if __name__ == '__main__':
+    unittest.main()
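
Usage sketch: a minimal way to exercise the new download_history_data flag, mirroring the test above. The entity_id and date range are illustrative, and a locally configured QMT/xtquant environment is assumed.

    import pandas as pd

    from zvt.contract import AdjustType, IntervalLevel
    from zvt.recorders.qmt.quotes import QMTStockKdataRecorder

    # Backfill 1-minute bars for one stock over a fixed window; xtdata downloads
    # are only triggered because download_history_data=True.
    QMTStockKdataRecorder(
        entity_id="stock_sz_000488",  # illustrative entity id
        adjust_type=AdjustType.qfq,
        level=IntervalLevel.LEVEL_1MIN,
        start_timestamp=pd.Timestamp("2024-01-02 09:30"),  # illustrative range
        end_timestamp=pd.Timestamp("2024-01-02 15:00"),
        download_history_data=True,
    ).run()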