Skip to content

Commit

Permalink
Rest
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Mar 28, 2024
1 parent 2271cda commit 2973082
Show file tree
Hide file tree
Showing 56 changed files with 1,192 additions and 507 deletions.
45 changes: 45 additions & 0 deletions examples/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Json
from sqlalchemy import Column, String, JSON
from sqlalchemy.orm import declarative_base

from zvt.contract.api import get_db_session
from zvt.contract.register import register_schema
from zvt.contract.schema import Mixin

ZvtInfoBase = declarative_base()


class User(Mixin, ZvtInfoBase):
__tablename__ = "user"
added_col = Column(String)
json_col = Column(JSON)


class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: str
entity_id: str
timestamp: datetime
added_col: str
json_col: Dict


register_schema(providers=["zvt"], db_name="test", schema_base=ZvtInfoBase)

if __name__ == "__main__":
user_model = UserModel(
id="user_cn_jack_2020-01-01",
entity_id="user_cn_jack",
timestamp="2020-01-01",
added_col="test",
json_col={"a": 1},
)
session = get_db_session(provider="zvt", data_schema=User)

user = session.query(User).filter(User.id == "user_cn_jack_2020-01-01").first()
print(UserModel.validate(user))
4 changes: 2 additions & 2 deletions examples/ml/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


def sgd_classification():
machine = MaStockMLMachine(entity_ids=["stock_sz_000001"], label_method="behavior_cls")
machine = MaStockMLMachine(data_provider="em", entity_ids=["stock_sz_000001"], label_method="behavior_cls")
clf = make_pipeline(StandardScaler(), SGDClassifier(max_iter=1000, tol=1e-3))
machine.train(model=clf)
machine.predict()
machine.draw_result(entity_id="stock_sz_000001")


def sgd_regressor():
machine = MaStockMLMachine(entity_ids=["stock_sz_000001"], label_method="raw")
machine = MaStockMLMachine(data_provider="em", entity_ids=["stock_sz_000001"], label_method="raw")
reg = make_pipeline(StandardScaler(), SGDRegressor(max_iter=1000, tol=1e-3))
machine.train(model=reg)
machine.predict()
Expand Down
4 changes: 4 additions & 0 deletions examples/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def inform(
entities = get_entities(
provider=entity_provider, entity_type=entity_type, entity_ids=entity_ids, return_type="domain"
)
entities = [entity for entity in entities if entity.entity_id in entity_ids]
print(len(entities))
print(len(entity_ids))
assert len(entities) == len(entity_ids)

if group_by_topic and (entity_type == "stock"):
StockNews.record_data(
Expand Down
2 changes: 1 addition & 1 deletion examples/trader/follow_ii_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def on_time(self, timestamp: pd.Timestamp):

if __name__ == "__main__":
code = "600519"
Stock.record_data(provider="eastmoney")
Stock.record_data(provider="em")
Stock1dKdata.record_data(code=code, provider="em")
StockActorSummary.record_data(code=code, provider="em")
FollowIITrader(
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
requests==2.31.0
SQLAlchemy==2.0.28
pandas==2.0.3
pydantic==2.6.4
arrow==1.2.3
openpyxl==3.1.1
demjson3==3.0.6
Expand All @@ -11,4 +12,5 @@ dash==2.8.1
jqdatapy==0.1.8
dash-bootstrap-components==1.3.1
dash_daq==0.5.0
scikit-learn==1.2.1
scikit-learn==1.2.1
fastapi==0.110.0
8 changes: 6 additions & 2 deletions src/zvt/api/intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,19 @@ def composite_df(df):
drawer.draw_pie(show=True)


def composite_all(data_schema, column, timestamp, entity_ids=None, filters=None):
def composite_all(data_schema, column, timestamp, provider=None, entity_ids=None, filters=None):
if type(column) is not str:
column = column.name
if filters:
filters.append([data_schema.timestamp == to_pd_timestamp(timestamp)])
else:
filters = [data_schema.timestamp == to_pd_timestamp(timestamp)]
df = data_schema.query_data(
entity_ids=entity_ids, columns=["entity_id", "timestamp", column], filters=filters, index="entity_id"
provider=provider,
entity_ids=entity_ids,
columns=["entity_id", "timestamp", column],
filters=filters,
index="entity_id",
)
entity_type, exchange, _ = decode_entity_id(df["entity_id"].iloc[0])
pie_df = pd.DataFrame(columns=df.index, data=[df[column].tolist()])
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/api/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_entity_ids_by_filter(
entity_ids=None,
ignore_bj=False,
):
filters = []
filters = [entity_schema.timestamp.isnot(None)]
if not target_date:
target_date = current_date()
if ignore_new_stock:
Expand Down
15 changes: 14 additions & 1 deletion src/zvt/contract/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import json
import logging
import os
import platform
Expand Down Expand Up @@ -58,11 +59,17 @@ def get_db_engine(
engine_key = "{}_{}".format(provider, db_name)
db_engine = zvt_context.db_engine_map.get(engine_key)
if not db_engine:
db_engine = create_engine("sqlite:///" + db_path, echo=False)
db_engine = create_engine(
"sqlite:///" + db_path, echo=False, json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False)
)
zvt_context.db_engine_map[engine_key] = db_engine
return db_engine


def get_providers() -> List[str]:
return zvt_context.providers


def get_schemas(provider: str) -> List[DeclarativeMeta]:
"""
get domain schemas supported by the provider
Expand Down Expand Up @@ -99,6 +106,7 @@ def get_db_session(provider: str, db_name: str = None, data_schema: object = Non
return get_db_session_factory(provider, db_name, data_schema)()

session = zvt_context.sessions.get(session_key)
# FIXME: should not maintain global session
if not session:
session = get_db_session_factory(provider, db_name, data_schema)()
zvt_context.sessions[session_key] = session
Expand All @@ -125,6 +133,9 @@ def get_db_session_factory(provider: str, db_name: str = None, data_schema: obje
return session


DBSession = get_db_session_factory


def get_entity_schema(entity_type: str) -> Type[TradableEntity]:
"""
get entity schema from name
Expand Down Expand Up @@ -664,9 +675,11 @@ def get_entity_ids(
__all__ = [
"_get_db_name",
"get_db_engine",
"get_providers",
"get_schemas",
"get_db_session",
"get_db_session_factory",
"DBSession",
"get_entity_schema",
"get_schema_by_name",
"get_schema_columns",
Expand Down
12 changes: 12 additions & 0 deletions src/zvt/contract/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from datetime import datetime

from pydantic import BaseModel, ConfigDict


class MixinModel(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: str
entity_id: str
timestamp: datetime
2 changes: 1 addition & 1 deletion src/zvt/contract/zvt_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RecorderState(ZvtInfoBase, StateMixin):
Schema for storing recorder state
"""

__tablename__ = "recoder_state"
__tablename__ = "recorder_state"


class TaggerState(ZvtInfoBase, StateMixin):
Expand Down
12 changes: 12 additions & 0 deletions src/zvt/domain/quotes/stock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@

__all__ += _stock_30m_hfq_kdata_all

# import all from submodule stock_l2quote_hfq_kdata
from .stock_l2quote_hfq_kdata import *
from .stock_l2quote_hfq_kdata import __all__ as _stock_l2quote_hfq_kdata_all

__all__ += _stock_l2quote_hfq_kdata_all

# import all from submodule stock_1mon_hfq_kdata
from .stock_1mon_hfq_kdata import *
from .stock_1mon_hfq_kdata import __all__ as _stock_1mon_hfq_kdata_all
Expand All @@ -91,6 +97,12 @@

__all__ += _stock_1h_hfq_kdata_all

# import all from submodule stock_l2quote_kdata
from .stock_l2quote_kdata import *
from .stock_l2quote_kdata import __all__ as _stock_l2quote_kdata_all

__all__ += _stock_l2quote_kdata_all

# import all from submodule stock_1m_kdata
from .stock_1m_kdata import *
from .stock_1m_kdata import __all__ as _stock_1m_kdata_all
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_15m_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Stock15mHfqKdata(KdataBase, StockKdataCommon):


register_schema(
providers=["joinquant", "em"], db_name="stock_15m_hfq_kdata", schema_base=KdataBase, entity_type="stock"
providers=["em", "joinquant"], db_name="stock_15m_hfq_kdata", schema_base=KdataBase, entity_type="stock"
)

# the __all__ is generated
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_15m_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock15mKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_15m_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_15m_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_15m_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock15mKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1d_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1dHfqKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1d_hfq_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1d_hfq_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1d_hfq_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1dHfqKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1d_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1dKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1d_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1d_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1d_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1dKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1h_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1hHfqKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1h_hfq_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1h_hfq_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1h_hfq_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1hHfqKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1h_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1hKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1h_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1h_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1h_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1hKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1m_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1mHfqKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1m_hfq_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1m_hfq_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1m_hfq_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1mHfqKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1m_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1mKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1m_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1m_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1m_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1mKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1mon_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Stock1monHfqKdata(KdataBase, StockKdataCommon):


register_schema(
providers=["joinquant", "em"], db_name="stock_1mon_hfq_kdata", schema_base=KdataBase, entity_type="stock"
providers=["em", "joinquant"], db_name="stock_1mon_hfq_kdata", schema_base=KdataBase, entity_type="stock"
)

# the __all__ is generated
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1mon_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1monKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1mon_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1mon_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1mon_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1monKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1wk_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Stock1wkHfqKdata(KdataBase, StockKdataCommon):


register_schema(
providers=["joinquant", "em"], db_name="stock_1wk_hfq_kdata", schema_base=KdataBase, entity_type="stock"
providers=["em", "joinquant"], db_name="stock_1wk_hfq_kdata", schema_base=KdataBase, entity_type="stock"
)

# the __all__ is generated
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_1wk_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock1wkKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_1wk_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_1wk_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_1wk_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock1wkKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_30m_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Stock30mHfqKdata(KdataBase, StockKdataCommon):


register_schema(
providers=["joinquant", "em"], db_name="stock_30m_hfq_kdata", schema_base=KdataBase, entity_type="stock"
providers=["em", "joinquant"], db_name="stock_30m_hfq_kdata", schema_base=KdataBase, entity_type="stock"
)

# the __all__ is generated
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_30m_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock30mKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_30m_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_30m_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_30m_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock30mKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_4h_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock4hHfqKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_4h_hfq_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_4h_hfq_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_4h_hfq_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock4hHfqKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_4h_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock4hKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_4h_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_4h_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_4h_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock4hKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_5m_hfq_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock5mHfqKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_5m_hfq_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_5m_hfq_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_5m_hfq_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock5mHfqKdata"]
2 changes: 1 addition & 1 deletion src/zvt/domain/quotes/stock/stock_5m_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Stock5mKdata(KdataBase, StockKdataCommon):
__tablename__ = "stock_5m_kdata"


register_schema(providers=["joinquant", "em"], db_name="stock_5m_kdata", schema_base=KdataBase, entity_type="stock")
register_schema(providers=["em", "joinquant"], db_name="stock_5m_kdata", schema_base=KdataBase, entity_type="stock")

# the __all__ is generated
__all__ = ["Stock5mKdata"]
Loading

0 comments on commit 2973082

Please sign in to comment.