diff --git a/examples/migration.py b/examples/migration.py new file mode 100644 index 00000000..52f27d1b --- /dev/null +++ b/examples/migration.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Json +from sqlalchemy import Column, String, JSON +from sqlalchemy.orm import declarative_base + +from zvt.contract.api import get_db_session +from zvt.contract.register import register_schema +from zvt.contract.schema import Mixin + +ZvtInfoBase = declarative_base() + + +class User(Mixin, ZvtInfoBase): + __tablename__ = "user" + added_col = Column(String) + json_col = Column(JSON) + + +class UserModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + entity_id: str + timestamp: datetime + added_col: str + json_col: Dict + + +register_schema(providers=["zvt"], db_name="test", schema_base=ZvtInfoBase) + +if __name__ == "__main__": + user_model = UserModel( + id="user_cn_jack_2020-01-01", + entity_id="user_cn_jack", + timestamp="2020-01-01", + added_col="test", + json_col={"a": 1}, + ) + session = get_db_session(provider="zvt", data_schema=User) + + user = session.query(User).filter(User.id == "user_cn_jack_2020-01-01").first() + print(UserModel.validate(user)) diff --git a/examples/ml/sgd.py b/examples/ml/sgd.py index c10bbf3d..4b0fea85 100644 --- a/examples/ml/sgd.py +++ b/examples/ml/sgd.py @@ -7,7 +7,7 @@ def sgd_classification(): - machine = MaStockMLMachine(entity_ids=["stock_sz_000001"], label_method="behavior_cls") + machine = MaStockMLMachine(data_provider="em", entity_ids=["stock_sz_000001"], label_method="behavior_cls") clf = make_pipeline(StandardScaler(), SGDClassifier(max_iter=1000, tol=1e-3)) machine.train(model=clf) machine.predict() @@ -15,7 +15,7 @@ def sgd_classification(): def sgd_regressor(): - machine = MaStockMLMachine(entity_ids=["stock_sz_000001"], label_method="raw") + machine = MaStockMLMachine(data_provider="em", entity_ids=["stock_sz_000001"], label_method="raw") reg = make_pipeline(StandardScaler(), SGDRegressor(max_iter=1000, tol=1e-3)) machine.train(model=reg) machine.predict() diff --git a/examples/report_utils.py b/examples/report_utils.py index 6c7c1e72..3c28d90d 100644 --- a/examples/report_utils.py +++ b/examples/report_utils.py @@ -39,6 +39,10 @@ def inform( entities = get_entities( provider=entity_provider, entity_type=entity_type, entity_ids=entity_ids, return_type="domain" ) + entities = [entity for entity in entities if entity.entity_id in entity_ids] + print(len(entities)) + print(len(entity_ids)) + assert len(entities) == len(entity_ids) if group_by_topic and (entity_type == "stock"): StockNews.record_data( diff --git a/examples/trader/follow_ii_trader.py b/examples/trader/follow_ii_trader.py index 5348d398..96d6bf88 100644 --- a/examples/trader/follow_ii_trader.py +++ b/examples/trader/follow_ii_trader.py @@ -44,7 +44,7 @@ def on_time(self, timestamp: pd.Timestamp): if __name__ == "__main__": code = "600519" - Stock.record_data(provider="eastmoney") + Stock.record_data(provider="em") Stock1dKdata.record_data(code=code, provider="em") StockActorSummary.record_data(code=code, provider="em") FollowIITrader( diff --git a/requirements.txt b/requirements.txt index ee637315..eeea82da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ requests==2.31.0 SQLAlchemy==2.0.28 pandas==2.0.3 +pydantic==2.6.4 arrow==1.2.3 openpyxl==3.1.1 demjson3==3.0.6 @@ -11,4 +12,5 @@ dash==2.8.1 jqdatapy==0.1.8 dash-bootstrap-components==1.3.1 dash_daq==0.5.0 -scikit-learn==1.2.1 \ No newline at end of file +scikit-learn==1.2.1 +fastapi==0.110.0 \ No newline at end of file diff --git a/src/zvt/api/intent.py b/src/zvt/api/intent.py index 35b4507d..687dc878 100644 --- a/src/zvt/api/intent.py +++ b/src/zvt/api/intent.py @@ -127,7 +127,7 @@ def composite_df(df): drawer.draw_pie(show=True) -def composite_all(data_schema, column, timestamp, entity_ids=None, filters=None): +def composite_all(data_schema, column, timestamp, provider=None, entity_ids=None, filters=None): if type(column) is not str: column = column.name if filters: @@ -135,7 +135,11 @@ def composite_all(data_schema, column, timestamp, entity_ids=None, filters=None) else: filters = [data_schema.timestamp == to_pd_timestamp(timestamp)] df = data_schema.query_data( - entity_ids=entity_ids, columns=["entity_id", "timestamp", column], filters=filters, index="entity_id" + provider=provider, + entity_ids=entity_ids, + columns=["entity_id", "timestamp", column], + filters=filters, + index="entity_id", ) entity_type, exchange, _ = decode_entity_id(df["entity_id"].iloc[0]) pie_df = pd.DataFrame(columns=df.index, data=[df[column].tolist()]) diff --git a/src/zvt/api/selector.py b/src/zvt/api/selector.py index 78413318..2ab90135 100644 --- a/src/zvt/api/selector.py +++ b/src/zvt/api/selector.py @@ -35,7 +35,7 @@ def get_entity_ids_by_filter( entity_ids=None, ignore_bj=False, ): - filters = [] + filters = [entity_schema.timestamp.isnot(None)] if not target_date: target_date = current_date() if ignore_new_stock: diff --git a/src/zvt/contract/api.py b/src/zvt/contract/api.py index 3da2ae0e..ab41453a 100644 --- a/src/zvt/contract/api.py +++ b/src/zvt/contract/api.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import json import logging import os import platform @@ -58,11 +59,17 @@ def get_db_engine( engine_key = "{}_{}".format(provider, db_name) db_engine = zvt_context.db_engine_map.get(engine_key) if not db_engine: - db_engine = create_engine("sqlite:///" + db_path, echo=False) + db_engine = create_engine( + "sqlite:///" + db_path, echo=False, json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False) + ) zvt_context.db_engine_map[engine_key] = db_engine return db_engine +def get_providers() -> List[str]: + return zvt_context.providers + + def get_schemas(provider: str) -> List[DeclarativeMeta]: """ get domain schemas supported by the provider @@ -99,6 +106,7 @@ def get_db_session(provider: str, db_name: str = None, data_schema: object = Non return get_db_session_factory(provider, db_name, data_schema)() session = zvt_context.sessions.get(session_key) + # FIXME: should not maintain global session if not session: session = get_db_session_factory(provider, db_name, data_schema)() zvt_context.sessions[session_key] = session @@ -125,6 +133,9 @@ def get_db_session_factory(provider: str, db_name: str = None, data_schema: obje return session +DBSession = get_db_session_factory + + def get_entity_schema(entity_type: str) -> Type[TradableEntity]: """ get entity schema from name @@ -664,9 +675,11 @@ def get_entity_ids( __all__ = [ "_get_db_name", "get_db_engine", + "get_providers", "get_schemas", "get_db_session", "get_db_session_factory", + "DBSession", "get_entity_schema", "get_schema_by_name", "get_schema_columns", diff --git a/src/zvt/contract/model.py b/src/zvt/contract/model.py new file mode 100644 index 00000000..7dfaa2a4 --- /dev/null +++ b/src/zvt/contract/model.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + + +class MixinModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + entity_id: str + timestamp: datetime diff --git a/src/zvt/contract/zvt_info.py b/src/zvt/contract/zvt_info.py index b2fa4c90..1c35dd3c 100644 --- a/src/zvt/contract/zvt_info.py +++ b/src/zvt/contract/zvt_info.py @@ -21,7 +21,7 @@ class RecorderState(ZvtInfoBase, StateMixin): Schema for storing recorder state """ - __tablename__ = "recoder_state" + __tablename__ = "recorder_state" class TaggerState(ZvtInfoBase, StateMixin): diff --git a/src/zvt/domain/quotes/stock/__init__.py b/src/zvt/domain/quotes/stock/__init__.py index 07c8599b..b967e13b 100644 --- a/src/zvt/domain/quotes/stock/__init__.py +++ b/src/zvt/domain/quotes/stock/__init__.py @@ -67,6 +67,12 @@ __all__ += _stock_30m_hfq_kdata_all +# import all from submodule stock_l2quote_hfq_kdata +from .stock_l2quote_hfq_kdata import * +from .stock_l2quote_hfq_kdata import __all__ as _stock_l2quote_hfq_kdata_all + +__all__ += _stock_l2quote_hfq_kdata_all + # import all from submodule stock_1mon_hfq_kdata from .stock_1mon_hfq_kdata import * from .stock_1mon_hfq_kdata import __all__ as _stock_1mon_hfq_kdata_all @@ -91,6 +97,12 @@ __all__ += _stock_1h_hfq_kdata_all +# import all from submodule stock_l2quote_kdata +from .stock_l2quote_kdata import * +from .stock_l2quote_kdata import __all__ as _stock_l2quote_kdata_all + +__all__ += _stock_l2quote_kdata_all + # import all from submodule stock_1m_kdata from .stock_1m_kdata import * from .stock_1m_kdata import __all__ as _stock_1m_kdata_all diff --git a/src/zvt/domain/quotes/stock/stock_15m_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_15m_hfq_kdata.py index eac6b1a8..33bbe1bc 100644 --- a/src/zvt/domain/quotes/stock/stock_15m_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_15m_hfq_kdata.py @@ -13,7 +13,7 @@ class Stock15mHfqKdata(KdataBase, StockKdataCommon): register_schema( - providers=["joinquant", "em"], db_name="stock_15m_hfq_kdata", schema_base=KdataBase, entity_type="stock" + providers=["em", "joinquant"], db_name="stock_15m_hfq_kdata", schema_base=KdataBase, entity_type="stock" ) # the __all__ is generated diff --git a/src/zvt/domain/quotes/stock/stock_15m_kdata.py b/src/zvt/domain/quotes/stock/stock_15m_kdata.py index 879eecad..6dc8fc23 100644 --- a/src/zvt/domain/quotes/stock/stock_15m_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_15m_kdata.py @@ -12,7 +12,7 @@ class Stock15mKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_15m_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_15m_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_15m_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock15mKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1d_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_1d_hfq_kdata.py index be20a0d0..2a8fcda7 100644 --- a/src/zvt/domain/quotes/stock/stock_1d_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1d_hfq_kdata.py @@ -12,7 +12,7 @@ class Stock1dHfqKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1d_hfq_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1d_hfq_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1d_hfq_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1dHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1d_kdata.py b/src/zvt/domain/quotes/stock/stock_1d_kdata.py index 9df5e686..f00272a1 100644 --- a/src/zvt/domain/quotes/stock/stock_1d_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1d_kdata.py @@ -12,7 +12,7 @@ class Stock1dKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1d_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1d_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1d_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1dKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1h_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_1h_hfq_kdata.py index 87343684..9ef2c38f 100644 --- a/src/zvt/domain/quotes/stock/stock_1h_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1h_hfq_kdata.py @@ -12,7 +12,7 @@ class Stock1hHfqKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1h_hfq_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1h_hfq_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1h_hfq_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1hHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1h_kdata.py b/src/zvt/domain/quotes/stock/stock_1h_kdata.py index 5143b47b..deee018c 100644 --- a/src/zvt/domain/quotes/stock/stock_1h_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1h_kdata.py @@ -12,7 +12,7 @@ class Stock1hKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1h_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1h_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1h_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1hKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1m_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_1m_hfq_kdata.py index 800401cb..decb4373 100644 --- a/src/zvt/domain/quotes/stock/stock_1m_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1m_hfq_kdata.py @@ -12,7 +12,7 @@ class Stock1mHfqKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1m_hfq_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1m_hfq_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1m_hfq_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1mHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1m_kdata.py b/src/zvt/domain/quotes/stock/stock_1m_kdata.py index c7e8fd82..8bda2d87 100644 --- a/src/zvt/domain/quotes/stock/stock_1m_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1m_kdata.py @@ -12,7 +12,7 @@ class Stock1mKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1m_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1m_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1m_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1mKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1mon_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_1mon_hfq_kdata.py index eef6e6ec..cfd3ebe6 100644 --- a/src/zvt/domain/quotes/stock/stock_1mon_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1mon_hfq_kdata.py @@ -13,7 +13,7 @@ class Stock1monHfqKdata(KdataBase, StockKdataCommon): register_schema( - providers=["joinquant", "em"], db_name="stock_1mon_hfq_kdata", schema_base=KdataBase, entity_type="stock" + providers=["em", "joinquant"], db_name="stock_1mon_hfq_kdata", schema_base=KdataBase, entity_type="stock" ) # the __all__ is generated diff --git a/src/zvt/domain/quotes/stock/stock_1mon_kdata.py b/src/zvt/domain/quotes/stock/stock_1mon_kdata.py index 91c53947..6bf77c0a 100644 --- a/src/zvt/domain/quotes/stock/stock_1mon_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1mon_kdata.py @@ -12,7 +12,7 @@ class Stock1monKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1mon_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1mon_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1mon_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1monKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_1wk_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_1wk_hfq_kdata.py index c91a5b34..930c2a14 100644 --- a/src/zvt/domain/quotes/stock/stock_1wk_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1wk_hfq_kdata.py @@ -13,7 +13,7 @@ class Stock1wkHfqKdata(KdataBase, StockKdataCommon): register_schema( - providers=["joinquant", "em"], db_name="stock_1wk_hfq_kdata", schema_base=KdataBase, entity_type="stock" + providers=["em", "joinquant"], db_name="stock_1wk_hfq_kdata", schema_base=KdataBase, entity_type="stock" ) # the __all__ is generated diff --git a/src/zvt/domain/quotes/stock/stock_1wk_kdata.py b/src/zvt/domain/quotes/stock/stock_1wk_kdata.py index 87180991..7a17db75 100644 --- a/src/zvt/domain/quotes/stock/stock_1wk_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_1wk_kdata.py @@ -12,7 +12,7 @@ class Stock1wkKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_1wk_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_1wk_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_1wk_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock1wkKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_30m_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_30m_hfq_kdata.py index 9f6f01f7..8dd005e3 100644 --- a/src/zvt/domain/quotes/stock/stock_30m_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_30m_hfq_kdata.py @@ -13,7 +13,7 @@ class Stock30mHfqKdata(KdataBase, StockKdataCommon): register_schema( - providers=["joinquant", "em"], db_name="stock_30m_hfq_kdata", schema_base=KdataBase, entity_type="stock" + providers=["em", "joinquant"], db_name="stock_30m_hfq_kdata", schema_base=KdataBase, entity_type="stock" ) # the __all__ is generated diff --git a/src/zvt/domain/quotes/stock/stock_30m_kdata.py b/src/zvt/domain/quotes/stock/stock_30m_kdata.py index 65ba4c0e..34aa5925 100644 --- a/src/zvt/domain/quotes/stock/stock_30m_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_30m_kdata.py @@ -12,7 +12,7 @@ class Stock30mKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_30m_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_30m_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_30m_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock30mKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_4h_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_4h_hfq_kdata.py index 0831599a..dac13883 100644 --- a/src/zvt/domain/quotes/stock/stock_4h_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_4h_hfq_kdata.py @@ -12,7 +12,7 @@ class Stock4hHfqKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_4h_hfq_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_4h_hfq_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_4h_hfq_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock4hHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_4h_kdata.py b/src/zvt/domain/quotes/stock/stock_4h_kdata.py index d65f96ef..271f5f2f 100644 --- a/src/zvt/domain/quotes/stock/stock_4h_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_4h_kdata.py @@ -12,7 +12,7 @@ class Stock4hKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_4h_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_4h_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_4h_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock4hKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_5m_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_5m_hfq_kdata.py index 68812a2b..2f2a9769 100644 --- a/src/zvt/domain/quotes/stock/stock_5m_hfq_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_5m_hfq_kdata.py @@ -12,7 +12,7 @@ class Stock5mHfqKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_5m_hfq_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_5m_hfq_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_5m_hfq_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock5mHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_5m_kdata.py b/src/zvt/domain/quotes/stock/stock_5m_kdata.py index 5d69e232..97144f9f 100644 --- a/src/zvt/domain/quotes/stock/stock_5m_kdata.py +++ b/src/zvt/domain/quotes/stock/stock_5m_kdata.py @@ -12,7 +12,7 @@ class Stock5mKdata(KdataBase, StockKdataCommon): __tablename__ = "stock_5m_kdata" -register_schema(providers=["joinquant", "em"], db_name="stock_5m_kdata", schema_base=KdataBase, entity_type="stock") +register_schema(providers=["em", "joinquant"], db_name="stock_5m_kdata", schema_base=KdataBase, entity_type="stock") # the __all__ is generated __all__ = ["Stock5mKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_l2quote_hfq_kdata.py b/src/zvt/domain/quotes/stock/stock_l2quote_hfq_kdata.py new file mode 100644 index 00000000..0ebf935e --- /dev/null +++ b/src/zvt/domain/quotes/stock/stock_l2quote_hfq_kdata.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# this file is generated by gen_kdata_schema function, dont't change it +from sqlalchemy.orm import declarative_base + +from zvt.contract.register import register_schema +from zvt.domain.quotes import StockKdataCommon + +KdataBase = declarative_base() + + +class StockL2quoteHfqKdata(KdataBase, StockKdataCommon): + __tablename__ = "stock_l2quote_hfq_kdata" + + +register_schema( + providers=["em", "joinquant"], db_name="stock_l2quote_hfq_kdata", schema_base=KdataBase, entity_type="stock" +) + +# the __all__ is generated +__all__ = ["StockL2quoteHfqKdata"] diff --git a/src/zvt/domain/quotes/stock/stock_l2quote_kdata.py b/src/zvt/domain/quotes/stock/stock_l2quote_kdata.py new file mode 100644 index 00000000..e93772fc --- /dev/null +++ b/src/zvt/domain/quotes/stock/stock_l2quote_kdata.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# this file is generated by gen_kdata_schema function, dont't change it +from sqlalchemy.orm import declarative_base + +from zvt.contract.register import register_schema +from zvt.domain.quotes import StockKdataCommon + +KdataBase = declarative_base() + + +class StockL2quoteKdata(KdataBase, StockKdataCommon): + __tablename__ = "stock_l2quote_kdata" + + +register_schema( + providers=["em", "joinquant"], db_name="stock_l2quote_kdata", schema_base=KdataBase, entity_type="stock" +) + +# the __all__ is generated +__all__ = ["StockL2quoteKdata"] diff --git a/src/zvt/fill_project.py b/src/zvt/fill_project.py index cf73967f..d62ddc39 100644 --- a/src/zvt/fill_project.py +++ b/src/zvt/fill_project.py @@ -12,7 +12,7 @@ def gen_kdata_schemas(): # A股行情 gen_kdata_schema( pkg="zvt", - providers=["joinquant", "em"], + providers=["em", "joinquant"], entity_type="stock", levels=[level for level in IntervalLevel if level != IntervalLevel.LEVEL_TICK], adjust_types=[None, AdjustType.hfq], @@ -78,7 +78,7 @@ def gen_kdata_schemas(): pkg="zvt", providers=["sina"], entity_type="etf", levels=[IntervalLevel.LEVEL_1DAY], entity_in_submodule=True ) - # etf行情 + # currency行情 gen_kdata_schema( pkg="zvt", providers=["em"], entity_type="currency", levels=[IntervalLevel.LEVEL_1DAY], entity_in_submodule=True ) @@ -95,5 +95,5 @@ def gen_kdata_schemas(): # gen_exports('autocode') # gen_exports("ml") # gen_kdata_schemas() - gen_exports("recorders") - # gen_exports("domain") + # gen_exports("recorders") + gen_exports("tag") diff --git a/src/zvt/recorders/em/em_api.py b/src/zvt/recorders/em/em_api.py index 4abef240..e91eaac9 100644 --- a/src/zvt/recorders/em/em_api.py +++ b/src/zvt/recorders/em/em_api.py @@ -292,10 +292,10 @@ def get_em_data( resp.close() if json_result: - if "result" in json_result: + if json_result.get("result"): data: list = json_result["result"]["data"] need_next = pn < json_result["result"]["pages"] - elif "data" in json_result: + elif json_result.get("data"): data: list = json_result["data"] need_next = json_result["hasNext"] == 1 else: @@ -824,8 +824,9 @@ def to_zvt_code(code): # pprint(get_free_holders(code='000338', end_date='2021-03-31')) # pprint(get_ii_holder(code='000338', report_date='2021-03-31', # org_type=actor_type_to_org_type(ActorType.corporation))) - # pprint(get_ii_summary(code='000338', report_date='2021-03-31', - # org_type=actor_type_to_org_type(ActorType.corporation))) + print( + get_ii_summary(code="600519", report_date="2021-03-31", org_type=actor_type_to_org_type(ActorType.corporation)) + ) # df = get_kdata(entity_id="index_sz_399370", level="1wk") # df = get_tradable_list(entity_type="stockhk") # df = get_news("stock_sz_300999") @@ -854,8 +855,8 @@ def to_zvt_code(code): # df = get_kdata(entity_id="stock_bj_873693", level="1d") # print(df) # print(get_controlling_shareholder(code="000338")) - events = get_events(entity_id="stock_sz_300684") - print(events) + # events = get_events(entity_id="stock_sz_300684") + # print(events) # the __all__ is generated __all__ = [ diff --git a/src/zvt/ui/apps/__init__.py b/src/zvt/rest/__init__.py similarity index 100% rename from src/zvt/ui/apps/__init__.py rename to src/zvt/rest/__init__.py diff --git a/src/zvt/rest/data.py b/src/zvt/rest/data.py new file mode 100644 index 00000000..d4b55acf --- /dev/null +++ b/src/zvt/rest/data.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +from fastapi import APIRouter +from fastapi.encoders import jsonable_encoder + +import zvt.contract.api as contract_api +import zvt.contract as contract + +data_router = APIRouter( + prefix="/api/data", + tags=["data"], + responses={404: {"description": "Not found"}}, +) + + +@data_router.get( + "/providers", + response_model=list, +) +def get_data_providers(): + """ + Get data providers + """ + return contract_api.get_providers() + + +@data_router.get( + "/schemas", + response_model=list, +) +def get_data_schemas(provider): + """ + Get schemas by provider + """ + return [schema.__name__ for schema in contract_api.get_schemas(provider=provider)] + + +@data_router.get( + "/query_data", + response_model=list, +) +def query_data(provider: str, schema: str): + """ + Get schemas by provider + """ + model: contract.Mixin = contract_api.get_schema_by_name(schema) + with contract_api.DBSession(provider=provider, data_schema=model)() as session: + return jsonable_encoder(model.query_data(session=session, limit=100, return_type="domain")) diff --git a/src/zvt/rest/work.py b/src/zvt/rest/work.py new file mode 100644 index 00000000..55ed9323 --- /dev/null +++ b/src/zvt/rest/work.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +from typing import List + +from fastapi import APIRouter, Query, HTTPException + +from zvt.contract.api import get_db_session +from zvt.tag.dataset.stock_tags import ( + StockTagsModel, + StockTags, + CreateStockTagsModel, + TagsInfoModel, + TagsInfo, + CreateTagsInfoModel, + SubTagsInfo, + HiddenTagsInfo, +) +from zvt.utils import is_same_date, today, current_date, to_time_str +from zvt.utils.model_utils import update_model +import zvt.contract.api as contract_api + +work_router = APIRouter( + prefix="/api/work", + tags=["work"], + responses={404: {"description": "Not found"}}, +) + + +@work_router.get("/get_tags_info", response_model=List[TagsInfoModel]) +def get_tags_info(): + """ + Get tags info + """ + with contract_api.DBSession(provider="zvt", data_schema=TagsInfo)() as session: + tags_info: List[TagsInfo] = TagsInfo.query_data(session=session, return_type="domain") + return tags_info + + +@work_router.get("/get_sub_tags_info", response_model=List[TagsInfoModel]) +def get_sub_tags_info(): + """ + Get sub tags info + """ + with contract_api.DBSession(provider="zvt", data_schema=SubTagsInfo)() as session: + tags_info: List[SubTagsInfo] = SubTagsInfo.query_data(session=session, return_type="domain") + return tags_info + + +@work_router.get("/get_hidden_tags_info", response_model=List[TagsInfoModel]) +def get_tags_info(): + """ + Get hidden tags info + """ + with contract_api.DBSession(provider="zvt", data_schema=TagsInfo)() as session: + tags_info: List[HiddenTagsInfo] = HiddenTagsInfo.query_data(session=session, return_type="domain") + return tags_info + + +def _create_tags_info(tags_info: CreateTagsInfoModel, tags_info_type="tags_info"): + """ + Create tags info + """ + if tags_info_type == "tags_info": + data_schema = TagsInfo + elif tags_info_type == "sub_tags_info": + data_schema = SubTagsInfo + elif tags_info_type == "hidden_tags_info": + data_schema = HiddenTagsInfo + else: + assert False + + with contract_api.DBSession(provider="zvt", data_schema=data_schema)() as session: + current_tags_info = data_schema.query_data( + session=session, filters=[data_schema.tag == tags_info.tag], return_type="domain" + ) + if current_tags_info: + raise HTTPException(status_code=409, detail=f"This tag has been registered in {tags_info_type}") + timestamp = current_date() + entity_id = "admin" + tags_info_db = data_schema( + id=f"admin_{to_time_str(timestamp)}_{tags_info.tag}", + entity_id=entity_id, + timestamp=timestamp, + tag=tags_info.tag, + tag_reason=tags_info.tag_reason, + ) + session.add(tags_info_db) + session.commit() + session.refresh(tags_info_db) + return tags_info_db + + +@work_router.post("/create_tags_info", response_model=TagsInfoModel) +def create_tags_info(tags_info: CreateTagsInfoModel): + return _create_tags_info(tags_info, tags_info_type="tags_info") + + +@work_router.post("/create_sub_tags_info", response_model=TagsInfoModel) +def create_sub_tags_info(tags_info: CreateTagsInfoModel): + return _create_tags_info(tags_info, tags_info_type="sub_tags_info") + + +@work_router.post("/create_hidden_tags_info", response_model=TagsInfoModel) +def create_hidden_tags_info(tags_info: CreateTagsInfoModel): + return _create_tags_info(tags_info, tags_info_type="hidden_tags_info") + + +@work_router.get("/get_stock_tags", response_model=List[StockTagsModel]) +def get_stock_tags(entity_ids: str = Query(None), history_data: bool = Query(False)): + """ + Get entity tags + """ + filters = [] + if not history_data: + filters = filters + [StockTags.latest.is_(True)] + if entity_ids: + filters = filters + [StockTags.entity_id.in_(entity_ids.split(","))] + with contract_api.DBSession(provider="zvt", data_schema=StockTags)() as session: + tags: List[StockTags] = StockTags.query_data( + session=session, filters=filters, return_type="domain", order=StockTags.timestamp.desc() + ) + return tags + + +@work_router.post("/create_stock_tags", response_model=StockTagsModel) +def create_stock_tags(stock_tags: CreateStockTagsModel): + """ + Set entity tags + """ + with contract_api.DBSession(provider="zvt", data_schema=StockTags)() as session: + entity_id = stock_tags.entity_id + tags = {} + sub_tags = {} + timestamp = current_date() + datas = StockTags.query_data( + session=session, entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain" + ) + if datas: + current_stock_tags: StockTags = datas[0] + if current_stock_tags.tags: + tags = dict(current_stock_tags.tags) + if current_stock_tags.sub_tags: + sub_tags = dict(current_stock_tags.sub_tags) + if is_same_date(current_stock_tags.timestamp, timestamp): + stock_tags_db = current_stock_tags + else: + current_stock_tags.latest = False + stock_tags_db = StockTags( + id=f"{entity_id}_{to_time_str(timestamp)}", + entity_id=entity_id, + timestamp=timestamp, + ) + session.add(current_stock_tags) + else: + stock_tags_db = StockTags( + id=f"{entity_id}_{to_time_str(timestamp)}", + entity_id=entity_id, + timestamp=timestamp, + ) + + update_model(stock_tags_db, stock_tags) + tags[stock_tags.tag] = stock_tags.tag_reason + if stock_tags.sub_tag: + sub_tags[stock_tags.sub_tag] = stock_tags.sub_tag_reason + + stock_tags_db.latest = True + # update + stock_tags_db.tags = tags + stock_tags_db.sub_tags = sub_tags + session.add(stock_tags_db) + session.commit() + session.refresh(stock_tags_db) + return stock_tags_db diff --git a/src/zvt/server.py b/src/zvt/server.py new file mode 100644 index 00000000..6b9aeda2 --- /dev/null +++ b/src/zvt/server.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +import uvicorn +from fastapi import FastAPI + +from zvt.rest.data import data_router +from zvt.rest.work import work_router + +app = FastAPI() + + +@app.get("/") +async def root(): + return {"message": "Hello World"} + + +app.include_router(data_router) +app.include_router(work_router) + +if __name__ == "__main__": + uvicorn.run("server:app", host="0.0.0.0", reload=True, port=8080) diff --git a/src/zvt/tag/__init__.py b/src/zvt/tag/__init__.py index e75bf1cd..998df29d 100644 --- a/src/zvt/tag/__init__.py +++ b/src/zvt/tag/__init__.py @@ -6,20 +6,26 @@ # common code of the package # export interface in __all__ which contains __all__ of its sub modules +# import all from submodule tag_utils +from .tag_utils import * +from .tag_utils import __all__ as _tag_utils_all + +__all__ += _tag_utils_all + # import all from submodule dataset from .dataset import * from .dataset import __all__ as _dataset_all __all__ += _dataset_all -# import all from submodule tags -from .tags import * -from .tags import __all__ as _tags_all +# import all from submodule tagger +from .tagger import * +from .tagger import __all__ as _tagger_all -__all__ += _tags_all +__all__ += _tagger_all -# import all from submodule tag -from .tag import * -from .tag import __all__ as _tag_all +# import all from submodule stock_auto_tagger +from .stock_auto_tagger import * +from .stock_auto_tagger import __all__ as _stock_auto_tagger_all -__all__ += _tag_all +__all__ += _stock_auto_tagger_all diff --git a/src/zvt/tag/dataset/stock_tags.py b/src/zvt/tag/dataset/stock_tags.py index 1e6bb93e..09995b3c 100644 --- a/src/zvt/tag/dataset/stock_tags.py +++ b/src/zvt/tag/dataset/stock_tags.py @@ -1,13 +1,48 @@ # -*- coding: utf-8 -*- -from sqlalchemy import Column, String +from typing import Dict, List, Union + +from pydantic import BaseModel +from sqlalchemy import Column, String, JSON, Boolean from sqlalchemy.orm import declarative_base from zvt.contract import Mixin +from zvt.contract.model import MixinModel from zvt.contract.register import register_schema StockTagsBase = declarative_base() +class TagsInfo(StockTagsBase, Mixin): + __tablename__ = "tags_info" + + tag = Column(String, unique=True) + tag_reason = Column(String) + + +class SubTagsInfo(StockTagsBase, Mixin): + __tablename__ = "sub_tags_info" + + tag = Column(String, unique=True) + tag_reason = Column(String) + + +class HiddenTagsInfo(StockTagsBase, Mixin): + __tablename__ = "hidden_tags_info" + + tag = Column(String, unique=True) + tag_reason = Column(String) + + +class TagsInfoModel(MixinModel): + tag: str + tag_reason: str + + +class CreateTagsInfoModel(BaseModel): + tag: str + tag_reason: str + + class StockTags(StockTagsBase, Mixin): """ Schema for storing stock tags @@ -15,16 +50,54 @@ class StockTags(StockTagsBase, Mixin): __tablename__ = "stock_tags" - #: :class:`~.zvt.tag.tags.actor_tag.ActorTag` values - actor_tag = Column(String(length=64)) - #: :class:`~.zvt.tag.tags.style_tag.StyleTag` values - style_tag = Column(String(length=64)) - #: :class:`~.zvt.tag.tags.cycle_tag.CycleTag` values - cycle_tag = Column(String(length=64)) - #: :class:`~.zvt.tag.tags.market_value_tag.MarketValueTag` values - market_value_tag = Column(String(length=64)) + tag = Column(String) + tag_reason = Column(String) + tags = Column(JSON) + + sub_tag = Column(String) + sub_tag_reason = Column(String) + sub_tags = Column(JSON) + + active_hidden_tags = Column(JSON) + hidden_tags = Column(JSON) + set_by_user = Column(Boolean, default=False) + #: 每个标的最后更新的记录 + latest = Column(Boolean, default=False) + + +class StockTagsModel(MixinModel): + tag: str + tag_reason: str + tags: Dict[str, str] + + sub_tag: Union[str, None] + sub_tag_reason: Union[str, None] + sub_tags: Union[Dict[str, str], None] + + active_hidden_tags: Union[List[str], None] + hidden_tags: Union[Dict[str, str], None] + set_by_user: bool = False + + +class CreateStockTagsModel(BaseModel): + entity_id: str + tag: str + tag_reason: str + sub_tag: Union[str, None] + sub_tag_reason: Union[str, None] + active_hidden_tags: Union[List[str], None] + hidden_tags: Union[Dict[str, str], None] register_schema(providers=["zvt"], db_name="stock_tags", schema_base=StockTagsBase) # the __all__ is generated -__all__ = ["StockTags"] +__all__ = [ + "TagsInfo", + "SubTagsInfo", + "HiddenTagsInfo", + "TagsInfoModel", + "CreateTagsInfoModel", + "StockTags", + "StockTagsModel", + "CreateStockTagsModel", +] diff --git a/src/zvt/tag/stock_auto_tagger.py b/src/zvt/tag/stock_auto_tagger.py new file mode 100644 index 00000000..65a35d49 --- /dev/null +++ b/src/zvt/tag/stock_auto_tagger.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +import logging +from enum import Enum +from typing import List + +import pandas as pd +import sqlalchemy +from zvt.api import get_recent_report +from zvt.api.selector import get_entity_ids_by_filter +from zvt.contract import ActorType +from zvt.contract.api import get_db_session, df_to_db +from zvt.domain import StockActorSummary, BlockStock, Block, Stock, StockEvents +from zvt.tag.dataset.stock_tags import StockTags, TagsInfo, SubTagsInfo, HiddenTagsInfo +from zvt.tag.tag_utils import industry_to_tag, get_concept_list, get_tags_info, get_sub_tags_info, get_hidden_tags_info +from zvt.tag.tagger import StockTagger +from zvt.utils import pd_is_not_null, is_same_date, count_interval, to_time_str, to_pd_timestamp + +logger = logging.getLogger(__name__) + + +def build_tags_info(): + tags_info = get_tags_info() + df = pd.DataFrame.from_records(tags_info) + df_to_db(df=df, data_schema=TagsInfo, provider="zvt", force_update=False) + + +def build_sub_tags_info(): + tags_info = get_sub_tags_info() + df = pd.DataFrame.from_records(tags_info) + df_to_db(df=df, data_schema=SubTagsInfo, provider="zvt", force_update=False) + + +def build_hidden_tags_info(): + tags_info = get_hidden_tags_info() + df = pd.DataFrame.from_records(tags_info) + df_to_db(df=df, data_schema=HiddenTagsInfo, provider="zvt", force_update=False) + + +class StockAutoTagger(StockTagger): + def __init__(self, force=False) -> None: + super().__init__(force) + datas = StockTags.query_data(limit=1, return_type="domain") + self.stock_tags_map = {} + if not datas: + self.init_default_tags(entity_ids=None) + + def init_default_tags(self, entity_ids=None): + stocks = Stock.query_data( + provider="em", entity_ids=entity_ids, return_type="domain", filters=[Stock.timestamp.isnot(None)] + ) + entity_map = {stock.entity_id: stock for stock in stocks} + entity_ids = entity_map.keys() + df_block = Block.query_data(provider="em", filters=[Block.category == "industry"]) + industry_codes = df_block["code"].tolist() + block_stocks: List[BlockStock] = BlockStock.query_data( + provider="em", + filters=[BlockStock.code.in_(industry_codes), BlockStock.stock_id.in_(entity_ids)], + return_type="domain", + ) + stock_tags = [] + for block_stock in block_stocks: + entity = entity_map.get(block_stock.stock_id) + if entity.id in self.stock_tags_map: + print(f"ignore {entity.id}") + continue + print(f"to {entity.id}") + + tag = industry_to_tag(industry=block_stock.name) + tag_reason = f"来自行业:{block_stock.name}" + + stock_tag = StockTags( + id=f"{entity.id}_{to_time_str(entity.timestamp)}", + entity_id=block_stock.stock_id, + timestamp=entity.timestamp, + tag=tag, + tag_reason=tag_reason, + tags={tag: tag_reason}, + set_by_user=False, + latest=True, + ) + stock_tags.append(stock_tag) + self.stock_tags_map[entity.id] = stock_tag + self.session.add_all(stock_tags) + self.session.commit() + + def build_sub_tags(self, entity_ids): + for entity_id in entity_ids: + print(f"to {entity_id}") + datas = StockTags.query_data( + entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain" + ) + if not datas: + print(f"ignore no tag:{entity_id}") + continue + start_timestamp = max(datas[0].timestamp, to_pd_timestamp("2005-01-01")) + print(f"start_timestamp: {start_timestamp}") + current_sub_tag = datas[0].sub_tag + filters = [StockEvents.event_type == "新增概念", StockEvents.entity_id == entity_id] + if current_sub_tag: + print(f"current_sub_tag: {current_sub_tag}") + filters = filters + [sqlalchemy.not_(StockEvents.level1_content.contains(current_sub_tag))] + stock_events: List[StockEvents] = StockEvents.query_data( + provider="em", + start_timestamp=start_timestamp, + filters=filters, + order=StockEvents.timestamp.asc(), + return_type="domain", + ) + if not stock_events: + print(f"no event for {entity_id}") + + for stock_event in stock_events: + datas: List[StockTags] = StockTags.query_data( + entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain" + ) + current_stock_tags: StockTags = datas[0] + + event_timestamp = to_pd_timestamp(stock_event.timestamp) + if stock_event.level1_content: + contents = stock_event.level1_content.split(":") + if len(contents) < 2: + print(f"wrong stock_event:{stock_event.level1_content}") + else: + sub_tag = contents[1] + if sub_tag in get_concept_list(return_checked=True): + if stock_event.level2_content: + sub_tag_reason = stock_event.level2_content.split(":")[1] + else: + sub_tag_reason = f"来自概念:{sub_tag}" + + if current_stock_tags.sub_tags: + sub_tags = dict(current_stock_tags.sub_tags) + else: + sub_tags = {} + if is_same_date(event_timestamp, current_stock_tags.timestamp): + stock_tag = current_stock_tags + else: + current_stock_tags.latest = False + stock_tag = StockTags( + id=f"{entity_id}_{to_time_str(event_timestamp)}", + entity_id=entity_id, + timestamp=event_timestamp, + tag=current_stock_tags.tag, + tag_reason=current_stock_tags.tag_reason, + tags=current_stock_tags.tags, + ) + sub_tags[sub_tag] = sub_tag_reason + stock_tag.sub_tag = sub_tag + stock_tag.sub_tag_reason = sub_tag_reason + stock_tag.sub_tags = sub_tags + stock_tag.latest = True + self.session.add(stock_tag) + self.session.commit() + else: + print(f"ignore concept: {sub_tag}") + + def tag(self): + entity_ids = get_entity_ids_by_filter( + provider="em", ignore_delist=False, ignore_st=False, ignore_new_stock=False + ) + for entity_id in entity_ids: + stock_tags = StockTags.query_data( + entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain" + ) + if not stock_tags: + self.init_default_tags(entity_ids=[entity_id]) + self.build_sub_tags(entity_ids=entity_ids) + + +if __name__ == "__main__": + # entity_ids = get_entity_ids_by_filter( + # provider="em", ignore_delist=False, ignore_st=False, ignore_new_stock=False + # ) + # session = get_db_session(provider="zvt", data_schema=StockTags) + # for entity_id in entity_ids: + # datas = StockTags.query_data( + # entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain" + # ) + # if datas: + # stock_tags = datas[0] + # stock_tags.latest=True + # session.add(stock_tags) + # session.commit() + # StockAutoTagger().tag() + build_tags_info() + build_sub_tags_info() + build_hidden_tags_info() + # from sqlalchemy import func + # + # df = StockTags.query_data( + # filters=[func.json_extract(StockTags.sub_tags, '$."麒麟电池"') != None], columns=[StockTags.sub_tags] + # ) + # print(df) +# the __all__ is generated +__all__ = ["build_tags_info", "build_sub_tags_info", "build_hidden_tags_info", "StockAutoTagger"] diff --git a/src/zvt/tag/tag.py b/src/zvt/tag/tag.py deleted file mode 100644 index f94e60b4..00000000 --- a/src/zvt/tag/tag.py +++ /dev/null @@ -1,106 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -from typing import Type - -import pandas as pd - -from zvt.contract import Mixin -from zvt.contract import TradableEntity -from zvt.contract.api import get_db_session -from zvt.contract.base_service import OneStateService -from zvt.contract.zvt_info import TaggerState -from zvt.domain import Stock, Index -from zvt.tag.dataset.stock_tags import StockTags -from zvt.utils import to_time_str, to_pd_timestamp -from zvt.utils.time_utils import TIME_FORMAT_DAY, now_pd_timestamp - -logger = logging.getLogger(__name__) - - -class Tagger(OneStateService): - state_schema = TaggerState - - entity_schema: Type[TradableEntity] = None - - data_schema: Type[Mixin] = None - - start_timestamp = "2005-01-01" - - def __init__(self, force=False) -> None: - super().__init__() - assert self.entity_schema is not None - assert self.data_schema is not None - self.force = force - self.session = get_db_session(provider="zvt", data_schema=self.data_schema) - if self.state and not self.force: - logger.info(f"get start_timestamp from state") - self.start_timestamp = self.state["current_timestamp"] - logger.info(f"tag start_timestamp: {self.start_timestamp}") - - def tag(self, timestamp): - raise NotImplementedError - - def get_tag_timestamps(self): - return pd.date_range(start=self.start_timestamp, end=now_pd_timestamp(), freq="M") - - def get_tag_domain(self, entity_id, timestamp, **fill_kv): - the_date = to_time_str(timestamp, fmt=TIME_FORMAT_DAY) - the_id = f"{entity_id}_{the_date}" - the_domain = self.data_schema.get_by_id(id=the_id) - - if the_domain: - for k, v in fill_kv.items(): - exec(f"the_domain.{k}=v") - else: - return self.data_schema(id=the_id, entity_id=entity_id, timestamp=to_pd_timestamp(the_date), **fill_kv) - return the_domain - - def get_tag_domains(self, entity_ids, timestamp, **fill_kv): - the_date = to_time_str(timestamp, fmt=TIME_FORMAT_DAY) - ids = [f"{entity_id}_{the_date}" for entity_id in entity_ids] - - the_domains = self.data_schema.query_data(ids=ids, return_type="domain") - - if the_domains: - for the_domain in the_domains: - for k, v in fill_kv.items(): - exec(f"the_domain.{k}=v") - - current_ids = [item.id for item in the_domains] - need_new_ids = set(ids) - set(current_ids) - new_domains = [ - self.data_schema( - id=f"{entity_id}_{the_date}", entity_id=entity_id, timestamp=to_pd_timestamp(the_date), **fill_kv - ) - for entity_id in need_new_ids - ] - return the_domains + new_domains - - def run(self): - timestamps = self.get_tag_timestamps() - for timestamp in timestamps: - logger.info(f"tag to {timestamp}") - self.tag(timestamp=timestamp) - - self.state = {"current_timestamp": to_time_str(timestamp)} - self.persist_state() - - -class StockTagger(Tagger): - data_schema = StockTags - entity_schema = Stock - - def tag(self, timestamp): - raise NotImplementedError - - -class IndexTagger(Tagger): - data_schema = StockTags - entity_schema = Index - - def tag(self, timestamp): - raise NotImplementedError - - -# the __all__ is generated -__all__ = ["Tagger", "StockTagger", "IndexTagger"] diff --git a/src/zvt/tag/tag_utils.py b/src/zvt/tag/tag_utils.py new file mode 100644 index 00000000..9a6b6385 --- /dev/null +++ b/src/zvt/tag/tag_utils.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +import json +import os +from collections import Counter + +from zvt.api import china_stock_code_to_id, get_china_exchange +from zvt.api.selector import get_entity_ids_by_filter +from zvt.api.tag import tag_stock, get_stock_tags, get_limit_up_reasons +from zvt.contract.api import get_entities +from zvt.domain import BlockStock, Block, Stock +from zvt.utils import current_date, next_date, pd_is_not_null, to_time_str + + +def get_industry_list(): + df = Block.query_data( + filters=[Block.category == "industry"], columns=[Block.name], return_type="df", order=Block.timestamp.desc() + ) + return df["name"].tolist() + + +def get_concept_list(return_checked=True): + checked_list = [ + "低空经济", + "AIPC", + "可控核聚变", + "AI手机", + "英伟达概念", + "Sora概念", + "飞行汽车(eVTOL)", + "PEEK材料概念", + "小米汽车", + "多模态AI", + "高带宽内存", + "短剧互动游戏", + "新型工业化", + "星闪概念", + "BC电池", + "SPD概念", + "减肥药", + "机器人执行器", + "高压快充", + "空间计算", + "裸眼3D", + "混合现实", + "存储芯片", + "液冷概念", + "光通信模块", + "数据要素", + "算力概念", + "MLOps概念", + "时空大数据", + "同步磁阻电机", + "数字水印", + "AI芯片", + "CPO概念", + "ChatGPT概念", + "电子后视镜", + "毫米波概念", + "蒙脱石散", + "血氧仪", + "第四代半导体", + "熊去氧胆酸", + "PLC概念", + "数据确权", + "抗原检测", + "抗菌面料", + "跨境电商", + "人造太阳", + "复合集流体", + "AIGC概念", + "Web3.0", + "供销社概念", + "创新药", + "信创", + "熔盐储能", + "Chiplet概念", + "减速器", + "轮毂电机", + "TOPCon电池", + "光伏高速公路", + "钒电池", + "钙钛矿电池", + "汽车一体化压铸", + "麒麟电池", + "机器人概念", + "汽车热管理", + "F5G概念", + "超超临界发电", + "千金藤素", + "新型城镇化", + "户外露营", + "肝炎概念", + "统一大市场", + "托育服务", + "跨境支付", + "气溶胶检测", + "东数西算", + "重组蛋白", + "数字经济", + "新冠检测", + "新冠药物", + "地下管网", + "虚拟数字人", + "DRG/DIP", + "动力电池回收", + "EDR概念", + "IGBT概念", + "数据安全", + "预制菜概念", + "培育钻石", + "发电机概念", + "华为欧拉", + "磷化工", + "国资云概念", + "元宇宙概念", + "碳基材料", + "工业母机", + "抽水蓄能", + "激光雷达", + "NFT概念", + "机器视觉", + "华为昇腾", + "空间站概念", + "储能", + "钠离子电池", + "盐湖提锂", + "CAR-T细胞疗法", + "换电概念", + "华为汽车", + "核污染防治", + "工业气体", + "光伏建筑一体化", + "碳化硅", + "被动元件", + "磁悬浮概念", + "汽车芯片", + "固态电池", + "碳交易", + "6G概念", + "虚拟电厂", + "鸿蒙概念", + "第三代半导体", + "刀片电池", + "氦气概念", + "MicroLED", + "EDA概念", + "装配建筑", + "汽车拆解", + "疫苗冷链", + "辅助生殖", + "长寿药", + "中芯概念", + "免税概念", + "北交所概念", + "数据中心", + "天基互联", + "特高压", + "半导体概念", + "氮化镓", + "降解塑料", + "HIT电池", + "转基因", + "流感", + "传感器", + "云游戏", + "MiniLED", + "CRO", + "阿兹海默", + "无线耳机", + "MLCC", + "医疗美容", + "农业种植", + "鸡肉概念", + "智慧政务", + "光刻胶", + "数字货币", + "PCB", + "ETC", + "垃圾分类", + "青蒿素", + "人造肉", + "电子烟", + "氢能源", + "超级真菌", + "数字孪生", + "边缘计算", + "工业大麻", + "纳米银", + "华为概念", + "电子竞技", + "体外诊断", + "知识产权", + "互联医疗", + "工业互联", + "小米概念", + "新零售", + "大飞机", + "雄安新区", + "共享经济", + "精准医疗", + "钛白粉", + "无线充电", + "草甘膦", + "网红直播", + "车联网", + "新能源车", + "国产芯片", + "单抗概念", + "区块链", + "工业4.0", + "人工智能", + "增强现实", + "无人驾驶", + "虚拟现实", + "5G概念", + "一带一路", + "量子通信", + "赛马概念", + "体育产业", + "人脑工程", + "无人机", + "超级电容", + "充电桩", + "全息技术", + "免疫治疗", + "基因测序", + "氟化工", + "燃料电池", + "智能家居", + "超导概念", + "独家药品", + "病毒防治", + "彩票概念", + "国家安防", + "电商概念", + "苹果概念", + "婴童概念", + "智能电视", + "网络安全", + "养老概念", + "特斯拉", + "智能机器", + "智能穿戴", + "互联金融", + "北斗导航", + "通用航空", + "3D打印", + "石墨烯", + "中药概念", + "食品安全", + "风能", + "智能电网", + "太阳能", + "水利建设", + "稀土永磁", + "核能核电", + "锂电池", + "创投", + ] + if return_checked: + return checked_list + + df = Block.query_data( + filters=[Block.category == "concept"], columns=[Block.name], return_type="df", order=Block.timestamp.desc() + ) + return df["name"].tolist() + + +_industry_to_tag_mapping = { + ("风电设备", "电池", "光伏设备", "能源金属", "电源设备"): "新能源", + ("半导体", "电子化学品"): "半导体", + ("医疗服务", "中药", "化学制药", "生物制品", "医药商业"): "医药", + ("医疗器械",): "医疗器械", + ("教育",): "教育", + ("贸易行业", "家用轻工", "造纸印刷", "酿酒行业", "珠宝首饰", "美容护理", "食品饮料", "旅游酒店", "商业百货", "纺织服装", "家电行业"): "大消费", + ("小金属", "贵金属", "有色金属", "煤炭行业"): "资源", + ("消费电子", "电子元件", "光学光电子"): "消费电子", + ("汽车零部件", "汽车服务", "汽车整车"): "汽车", + ("电机", "通用设备", "专用设备", "仪器仪表"): "智能机器", + ("电网设备", "电力行业"): "电力", + ("房地产开发", "房地产服务", "工程建设", "水泥建材", "装修装饰", "装修建材", "工程咨询服务", "钢铁行业", "工程机械"): "房地产", + ("非金属材料", "包装材料", "化学制品", "化肥行业", "化学原料", "化纤行业", "塑料制品", "玻璃玻纤", "橡胶制品"): "化工", + ("交运设备", "船舶制造", "航运港口", "公用事业", "燃气", "航空机场", "环保行业", "石油行业", "铁路公路", "采掘行业"): "公用", + ("证券", "保险", "银行", "多元金融"): "金融", + ("互联网服务", "软件开发", "计算机设备", "游戏", "通信服务", "通信设备"): "AI", + ("文化传媒",): "传媒", + ("农牧饲渔", "农药兽药"): "农业", + ("物流行业",): "物流", + ("航天航空", "船舶制造"): "军工", + ("专业服务",): "专业服务", +} + + +def get_tags_info(): + timestamp = "2024-03-25" + entity_id = "admin" + return [ + { + "id": f"admin_{to_time_str(timestamp)}_{tag}", + "entity_id": entity_id, + "timestamp": timestamp, + "tag": tag, + "tag_reason": f"来自这些行业:{','.join(industries)}", + } + for industries, tag in _industry_to_tag_mapping.items() + ] + + +def get_sub_tags_info(): + timestamp = "2024-03-25" + entity_id = "admin" + return [ + { + "id": f"admin_{to_time_str(timestamp)}_{tag}", + "entity_id": entity_id, + "timestamp": timestamp, + "tag": tag, + "tag_reason": tag, + } + for tag in get_concept_list(return_checked=True) + ] + + +_hidden_tags = { + "中字头": "央企,国资委控股", + "核心资产": "高ROE 高现金流 高股息 低应收 低资本开支 低财务杠杆 有增长", + "高股息": "高股息", + "微盘股": "市值50亿以下", + "次新股": "上市未满两年", +} + + +def get_hidden_tags_info(): + timestamp = "2024-03-25" + entity_id = "admin" + return [ + { + "id": f"admin_{to_time_str(timestamp)}_{tag}", + "entity_id": entity_id, + "timestamp": timestamp, + "tag": tag, + "tag_reason": tag_reason, + } + for tag, tag_reason in _hidden_tags.items() + ] + + +def industry_to_tag(industry): + for key, value in _industry_to_tag_mapping.items(): + if industry in key: + return value + return "未知行业" + + +def check_industry_list_consistency(): + all_industries = get_industry_list() + using_industries = [] + for key in _industry_to_tag_mapping.keys(): + for item in key: + using_industries.append(item) + if sorted(all_industries) == sorted(using_industries): + print(set(all_industries) - set(using_industries)) + print("failed") + assert False + print("ok") + + +if __name__ == "__main__": + check_industry_list_consistency() + print(get_tags_info()) +# the __all__ is generated +__all__ = [ + "get_industry_list", + "get_concept_list", + "get_tags_info", + "get_sub_tags_info", + "get_hidden_tags_info", + "industry_to_tag", + "check_industry_list_consistency", +] diff --git a/src/zvt/tag/tagger.py b/src/zvt/tag/tagger.py new file mode 100644 index 00000000..b7b507fa --- /dev/null +++ b/src/zvt/tag/tagger.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +import logging +from typing import Type + +import pandas as pd + +from zvt.contract import Mixin +from zvt.contract import TradableEntity +from zvt.contract.api import get_db_session +from zvt.contract.base_service import OneStateService, EntityStateService +from zvt.contract.zvt_info import TaggerState +from zvt.domain import Stock, Index +from zvt.tag.dataset.stock_tags import StockTags +from zvt.utils import to_time_str, to_pd_timestamp +from zvt.utils.time_utils import TIME_FORMAT_DAY, now_pd_timestamp + +logger = logging.getLogger(__name__) + + +class Tagger(OneStateService): + state_schema = TaggerState + + entity_schema: Type[TradableEntity] = None + + data_schema: Type[Mixin] = None + + start_timestamp = "2018-01-01" + + def __init__(self, force=False) -> None: + super().__init__() + assert self.entity_schema is not None + assert self.data_schema is not None + self.force = force + self.session = get_db_session(provider="zvt", data_schema=self.data_schema) + if self.state and not self.force: + logger.info(f"get start_timestamp from state") + self.start_timestamp = self.state["current_timestamp"] + logger.info(f"tag start_timestamp: {self.start_timestamp}") + + def tag(self): + raise NotImplementedError + + +class StockTagger(Tagger): + data_schema = StockTags + entity_schema = Stock + + def tag(self): + raise NotImplementedError + + +# the __all__ is generated +__all__ = ["Tagger", "StockTagger"] diff --git a/src/zvt/tag/tags/__init__.py b/src/zvt/tag/tags/__init__.py deleted file mode 100644 index 824b652d..00000000 --- a/src/zvt/tag/tags/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -# the __all__ is generated -__all__ = [] - -# __init__.py structure: -# common code of the package -# export interface in __all__ which contains __all__ of its sub modules - -# import all from submodule style_tag -from .style_tag import * -from .style_tag import __all__ as _style_tag_all - -__all__ += _style_tag_all - -# import all from submodule cycle_tag -from .cycle_tag import * -from .cycle_tag import __all__ as _cycle_tag_all - -__all__ += _cycle_tag_all - -# import all from submodule market_value_tag -from .market_value_tag import * -from .market_value_tag import __all__ as _market_value_tag_all - -__all__ += _market_value_tag_all - -# import all from submodule actor_tag -from .actor_tag import * -from .actor_tag import __all__ as _actor_tag_all - -__all__ += _actor_tag_all diff --git a/src/zvt/tag/tags/actor_tag.py b/src/zvt/tag/tags/actor_tag.py deleted file mode 100644 index ec82a092..00000000 --- a/src/zvt/tag/tags/actor_tag.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -from enum import Enum - -from zvt.api import get_recent_report -from zvt.contract import ActorType -from zvt.domain import StockActorSummary -from zvt.tag.dataset.stock_tags import StockTags -from zvt.tag.tag import StockTagger -from zvt.utils import pd_is_not_null - -logger = logging.getLogger(__name__) - - -class ActorTag(Enum): - # 基金心头好 - fund_love = "fund_love" - # 基金看不起 - fund_not_care = "fund_not_care" - - -class ActorTagger(StockTagger): - def tag(self, timestamp): - df = get_recent_report( - data_schema=StockActorSummary, - timestamp=timestamp, - filters=[StockActorSummary.actor_type == ActorType.raised_fund.value], - ) - if not pd_is_not_null(df): - logger.error(f"no StockActorSummary data at {timestamp}") - return - - df = df.set_index("entity_id") - - fund_love_ids = df[(df["holding_ratio"] >= 0.05) & (df["change_ratio"] >= -0.3)].index.tolist() - fund_not_care_ids = df[(df["holding_ratio"] < 0.05) | (df["change_ratio"] < -0.3)].index.tolist() - - fund_love_domains = self.get_tag_domains( - entity_ids=fund_love_ids, timestamp=timestamp, actor_tag=ActorTag.fund_love.value - ) - fund_not_care_domains = self.get_tag_domains( - entity_ids=fund_not_care_ids, timestamp=timestamp, actor_tag=ActorTag.fund_not_care.value - ) - self.session.add_all(fund_love_domains) - self.session.add_all(fund_not_care_domains) - self.session.commit() - - -if __name__ == "__main__": - ActorTagger().run() - # print(StockTags.query_data(start_timestamp="2021-08-31", filters=[StockTags.actor_tag != None])) -# the __all__ is generated -__all__ = ["ActorTag", "ActorTagger"] diff --git a/src/zvt/tag/tags/cycle_tag.py b/src/zvt/tag/tags/cycle_tag.py deleted file mode 100644 index 74a4fcd5..00000000 --- a/src/zvt/tag/tags/cycle_tag.py +++ /dev/null @@ -1,138 +0,0 @@ -# -*- coding: utf-8 -*- -from enum import Enum - -from zvt.domain import Stock, BlockStock, Block -from zvt.tag.dataset.stock_tags import StockTags -from zvt.tag.tag import StockTagger - - -class CycleTag(Enum): - # 强周期 - strong_cycle = "strong_cycle" - # 弱周期 - weak_cycle = "weak_cycle" - # 非周期,受周期影响不大,比如消费,医药 - non_cycle = "non_cycle" - - -# 这里用的是东财的行业分类 -# 数据来源 -# Block.record_data(provider='eastmoney') -# Block.query_data(provider='eastmoney', filters=[Block.category==industry]) -# BlockStock.record_data(provider='eastmoney') - -# 这不是一个严格的分类,好像也不需要太严格 -cycle_map_industry = { - CycleTag.strong_cycle: [ - "有色金属", - "水泥建材", - "化工行业", - "化纤行业", - "钢铁行业", - "煤炭采选", - "化肥行业", - "贵金属", - "船舶制造", - "房地产", - "石油行业", - "港口水运", - "材料行业", - ], - CycleTag.weak_cycle: [ - "电力行业", - "民航机场", - "家电行业", - "旅游酒店", - "银行", - "保险", - "高速公路", - "电子信息", - "通讯行业", - "多元金融", - "电子元件", - "国际贸易", - "珠宝首饰", - "交运物流", - "航天航空", - "交运设备", - "汽车行业", - "专用设备", - "园林工程", - "造纸印刷", - "安防设备", - "装修装饰", - "木业家具", - "输配电气", - "工程建设", - "包装材料", - "券商信托", - "机械行业", - "金属制品", - "塑胶制品", - "环保工程", - "玻璃陶瓷", - ], - CycleTag.non_cycle: [ - "食品饮料", - "酿酒行业", - "医疗行业", - "医药制造", - "文教休闲", - "软件服务", - "商业百货", - "文化传媒", - "农牧饲渔", - "公用事业", - "纺织服装", - "综合行业", - "仪器仪表", - "电信运营", - "工艺商品", - "农药兽药", - ], -} - - -def get_cycle_tag(industry_name): - for cycle_tag in cycle_map_industry: - if industry_name in cycle_map_industry.get(cycle_tag): - return cycle_tag.name - - -class CycleTagger(StockTagger): - def tag(self, timestamp): - stock_df = Stock.query_data(filters=[Stock.list_date <= timestamp], index="entity_id") - block_df = Block.query_data(provider="eastmoney", filters=[Block.category == "industry"], index="entity_id") - block_ids = block_df.index.tolist() - block_stock_df = BlockStock.query_data( - provider="eastmoney", - entity_ids=block_ids, - filters=[BlockStock.stock_id.in_(stock_df.index.tolist())], - index="stock_id", - ) - block_stock_df["cycle_tag"] = block_stock_df["name"].apply(lambda name: get_cycle_tag(name)) - strong_cycle_stocks = block_stock_df[block_stock_df.cycle_tag == "strong_cycle"]["stock_id"] - weak_cycle_stocks = block_stock_df[block_stock_df.cycle_tag == "weak_cycle"]["stock_id"] - non_cycle_stocks = block_stock_df[block_stock_df.cycle_tag == "non_cycle"]["stock_id"] - - strong_cycle_domains = self.get_tag_domains( - entity_ids=strong_cycle_stocks, timestamp=timestamp, cycle_tag=CycleTag.strong_cycle.value - ) - weak_cycle_domains = self.get_tag_domains( - entity_ids=weak_cycle_stocks, timestamp=timestamp, cycle_tag=CycleTag.weak_cycle.value - ) - non_cycle_domains = self.get_tag_domains( - entity_ids=non_cycle_stocks, timestamp=timestamp, cycle_tag=CycleTag.non_cycle.value - ) - - self.session.add_all(strong_cycle_domains) - self.session.add_all(weak_cycle_domains) - self.session.add_all(non_cycle_domains) - self.session.commit() - - -if __name__ == "__main__": - CycleTagger().run() - print(StockTags.query_data(start_timestamp="2021-08-31", filters=[StockTags.cycle_tag != None])) -# the __all__ is generated -__all__ = ["CycleTag", "get_cycle_tag", "CycleTagger"] diff --git a/src/zvt/tag/tags/market_value_tag.py b/src/zvt/tag/tags/market_value_tag.py deleted file mode 100644 index 875ad1f3..00000000 --- a/src/zvt/tag/tags/market_value_tag.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -from enum import Enum - -from zvt.domain import IndexStock -from zvt.tag.tag import StockTagger -from zvt.utils import pd_is_not_null -from zvt.utils.time_utils import month_start_date, month_end_date - -logger = logging.getLogger(__name__) - - -class MarketValueTag(Enum): - # 市值前500 - huge = "huge" - # 500-1000 - medium = "medium" - # 1000-> - small = "small" - - -index_map_market_value = { - # 399311 国证1000指数由A股市场中市值大、流动性好的1000只股票构成,反映A股市场大中盘股票的价格变动趋势。 - # 399400 大中盘指数由巨潮大盘指数样本股与巨潮中盘指数样本股组合而成,国证1000中的前500 - "index_sz_399400": MarketValueTag.huge, - # 399316 巨潮小盘指数以国证1000指数为样本空间,选取总市值排名在后500名的股票 - "index_sz_399316": MarketValueTag.medium, - # 399303 国证2000由全部A股扣除国证1000指数样本股后,市值大、流动性好的2000只A股组成,反映A股市场小微盘股票的价格变动趋势。 - "index_sz_399303": MarketValueTag.small, -} - - -class MarketValueTagger(StockTagger): - #: a is a number - a = 1 - - def tag(self, timestamp): - for index_id in index_map_market_value: - df = IndexStock.query_data( - entity_id=index_id, start_timestamp=month_start_date(timestamp), end_timestamp=month_end_date(timestamp) - ) - if not pd_is_not_null(df): - logger.error(f"no IndexStock data at {timestamp} for {index_id}") - continue - stock_tags = [ - self.get_tag_domain(entity_id=stock_id, timestamp=timestamp) for stock_id in df["stock_id"].tolist() - ] - - for stock_tag in stock_tags: - stock_tag.market_value_tag = index_map_market_value.get(index_id).value - - self.session.add_all(stock_tags) - self.session.commit() - - -if __name__ == "__main__": - print(MarketValueTagger.__doc__) - # MarketValueTagger().run() -# the __all__ is generated -__all__ = ["MarketValueTag", "MarketValueTagger"] diff --git a/src/zvt/tag/tags/style_tag.py b/src/zvt/tag/tags/style_tag.py deleted file mode 100644 index 11a00dbb..00000000 --- a/src/zvt/tag/tags/style_tag.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -from enum import Enum - -from zvt.domain import IndexStock -from zvt.tag.tag import StockTagger - -# define tag -from zvt.utils import pd_is_not_null -from zvt.utils.time_utils import month_start_date, month_end_date - -logger = logging.getLogger(__name__) - - -class StyleTag(Enum): - # 成长股 - growth = "growth" - # 价值股 - value = "value" - - -index_map_style = { - # 国证成长(399370) - # 在国证1000指数样本股中,选取主营业务收入增长率、净利润增长率和净资产收益率综合排名前332只 - "index_sz_399370": StyleTag.growth, - # 国证价值(399371) - # 在国证1000指数样本股中,选取每股收益与价格比率、每股经营现金流与价格比率、股息收益率、每股净资产与价格比率综合排名前332只 - "index_sz_399371": StyleTag.value, -} - - -class StyleTagger(StockTagger): - def tag(self, timestamp): - for index_id in index_map_style: - df = IndexStock.query_data( - entity_id=index_id, start_timestamp=month_start_date(timestamp), end_timestamp=month_end_date(timestamp) - ) - if not pd_is_not_null(df): - logger.error(f"no IndexStock data at {timestamp} for {index_id}") - continue - print(df) - stock_tags = [ - self.get_tag_domain(entity_id=stock_id, timestamp=timestamp) for stock_id in df["stock_id"].tolist() - ] - for stock_tag in stock_tags: - stock_tag.style_tag = index_map_style.get(index_id).value - self.session.add_all(stock_tags) - self.session.commit() - - -if __name__ == "__main__": - StyleTagger().run() -# the __all__ is generated -__all__ = ["StyleTag", "StyleTagger"] diff --git a/src/zvt/utils/model_utils.py b/src/zvt/utils/model_utils.py new file mode 100644 index 00000000..37349f04 --- /dev/null +++ b/src/zvt/utils/model_utils.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +def update_model(db_model, schema): + for key, value in schema.dict().items(): + if value is not None: + setattr(db_model, key, value) diff --git a/tests/api/test_intent.py b/tests/api/test_intent.py index 474f1652..97233e5e 100644 --- a/tests/api/test_intent.py +++ b/tests/api/test_intent.py @@ -66,5 +66,9 @@ def test_composite(): def test_composite_all(): composite_all( - entity_ids=None, data_schema=Stock1dKdata, column=Stock1dKdata.turnover, timestamp=to_pd_timestamp("2016-12-02") + provider="joinquant", + entity_ids=None, + data_schema=Stock1dKdata, + column=Stock1dKdata.turnover, + timestamp=to_pd_timestamp("2016-12-02"), ) diff --git a/tests/api/test_kdata.py b/tests/api/test_kdata.py index 135050a5..b0ea3a7a 100644 --- a/tests/api/test_kdata.py +++ b/tests/api/test_kdata.py @@ -42,5 +42,5 @@ def test_jq_1d_hfq_kdata(): def test_get_latest_kdata_date(): - date = get_latest_kdata_date(entity_type="stock", adjust_type=AdjustType.hfq) + date = get_latest_kdata_date(provider="joinquant", entity_type="stock", adjust_type=AdjustType.hfq) assert date is not None diff --git a/tests/contract/test_reader.py b/tests/contract/test_reader.py index 138e6ec2..ba30c699 100644 --- a/tests/contract/test_reader.py +++ b/tests/contract/test_reader.py @@ -15,6 +15,7 @@ def test_china_stock_reader(): data_reader = DataReader( + provider="joinquant", data_schema=Stock1dKdata, entity_schema=Stock, entity_provider="eastmoney", diff --git a/tests/factors/test_factor_select_targets.py b/tests/factors/test_factor_select_targets.py index 7b689776..c047f38e 100644 --- a/tests/factors/test_factor_select_targets.py +++ b/tests/factors/test_factor_select_targets.py @@ -14,6 +14,7 @@ def test_cross_ma_select_targets(): start_timestamp = "2018-01-01" end_timestamp = "2019-06-30" factor = CrossMaFactor( + provider="joinquant", entity_ids=entity_ids, start_timestamp=start_timestamp, end_timestamp=end_timestamp, diff --git a/tests/factors/test_factors.py b/tests/factors/test_factors.py index 640e30be..e02b06e6 100644 --- a/tests/factors/test_factors.py +++ b/tests/factors/test_factors.py @@ -3,8 +3,16 @@ def test_z_factor(): - z = ZFactor(codes=["000338"], need_persist=False) + z = ZFactor( + codes=["000338"], + need_persist=False, + provider="joinquant", + ) z.draw(show=True) - z = ZFactor(codes=["000338", "601318"], need_persist=True) + z = ZFactor( + codes=["000338", "601318"], + need_persist=True, + provider="joinquant", + ) z.draw(show=True) diff --git a/tests/ml/test_sgd.py b/tests/ml/test_sgd.py index ed8b4026..3246a1cd 100644 --- a/tests/ml/test_sgd.py +++ b/tests/ml/test_sgd.py @@ -13,6 +13,7 @@ def test_sgd_classification(): machine = MaStockMLMachine( + data_provider="joinquant", start_timestamp=start_timestamp, end_timestamp=end_timestamp, predict_start_timestamp=predict_start_timestamp, @@ -28,6 +29,7 @@ def test_sgd_classification(): def test_sgd_regressor(): machine = MaStockMLMachine( + data_provider="joinquant", start_timestamp=start_timestamp, end_timestamp=end_timestamp, predict_start_timestamp=predict_start_timestamp, diff --git a/tests/trader/test_trader.py b/tests/trader/test_trader.py index 17fd3439..457324cd 100644 --- a/tests/trader/test_trader.py +++ b/tests/trader/test_trader.py @@ -21,6 +21,7 @@ def long_position_control(self): def test_single_trader(): trader = SingleTrader( + provider="joinquant", codes=["000338"], level=IntervalLevel.LEVEL_1DAY, start_timestamp="2019-01-01", @@ -38,10 +39,18 @@ def test_single_trader(): print(account) buy_price = get_kdata( - entity_id="stock_sz_000338", start_timestamp=buy_timestamp, end_timestamp=buy_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sz_000338", + start_timestamp=buy_timestamp, + end_timestamp=buy_timestamp, + return_type="domain", )[0] sell_price = get_kdata( - entity_id="stock_sz_000338", start_timestamp=sell_timestamp, end_timestamp=sell_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sz_000338", + start_timestamp=sell_timestamp, + end_timestamp=sell_timestamp, + return_type="domain", )[0] sell_lost = trader.account_service.slippage + trader.account_service.sell_cost @@ -78,6 +87,7 @@ def long_position_control(self): def test_multiple_trader(): trader = MultipleTrader( + provider="joinquant", codes=["000338", "601318"], level=IntervalLevel.LEVEL_1DAY, start_timestamp="2019-01-01", @@ -97,10 +107,18 @@ def test_multiple_trader(): # 000338 buy_price = get_kdata( - entity_id="stock_sz_000338", start_timestamp=buy_timestamp, end_timestamp=buy_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sz_000338", + start_timestamp=buy_timestamp, + end_timestamp=buy_timestamp, + return_type="domain", )[0] sell_price = get_kdata( - entity_id="stock_sz_000338", start_timestamp=sell_timestamp, end_timestamp=sell_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sz_000338", + start_timestamp=sell_timestamp, + end_timestamp=sell_timestamp, + return_type="domain", )[0] sell_lost = trader.account_service.slippage + trader.account_service.sell_cost @@ -109,10 +127,18 @@ def test_multiple_trader(): # 601318 buy_price = get_kdata( - entity_id="stock_sh_601318", start_timestamp=buy_timestamp, end_timestamp=buy_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sh_601318", + start_timestamp=buy_timestamp, + end_timestamp=buy_timestamp, + return_type="domain", )[0] sell_price = get_kdata( - entity_id="stock_sh_601318", start_timestamp=sell_timestamp, end_timestamp=sell_timestamp, return_type="domain" + provider="joinquant", + entity_id="stock_sh_601318", + start_timestamp=sell_timestamp, + end_timestamp=sell_timestamp, + return_type="domain", )[0] pct2 = (sell_price.close * (1 - sell_lost) - buy_price.close * (1 + buy_lost)) / buy_price.close * (1 + buy_lost) @@ -125,6 +151,7 @@ def test_multiple_trader(): def test_basic_trader(): try: MyBullTrader( + provider="joinquant", codes=["000338"], level=IntervalLevel.LEVEL_1DAY, start_timestamp="2018-01-01",