Skip to content

Commit

Permalink
Move some common feature from examples to main package
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Jun 5, 2024
1 parent a129db8 commit ec1a373
Show file tree
Hide file tree
Showing 21 changed files with 137 additions and 82 deletions.
6 changes: 1 addition & 5 deletions examples/README.MD
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
The examples is updating with master source code.
Run with specific version e.g. v0.9.16 need to
```
git checkout v0.10.1
```
The examples are updating with master source code.
7 changes: 4 additions & 3 deletions examples/query_snippet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def get_sub_tag_stocks(tag="低空经济"):


if __name__ == "__main__":
a = get_block_stocks()
b = get_sub_tag_stocks()
print(set(a) - set(b))
# a = get_block_stocks()
# b = get_sub_tag_stocks()
# print(set(a) - set(b))
print(Stock.query_data(provider="em", return_type="dict"))
3 changes: 2 additions & 1 deletion examples/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Type

from examples.tag_utils import group_stocks_by_tag, get_main_line_tags, get_main_line_hidden_tags
from examples.utils import add_to_eastmoney, msg_group_stocks_by_topic
from examples.utils import msg_group_stocks_by_topic
from zvt import zvt_config
from zvt.api import get_top_volume_entities, TopType
from zvt.api.kdata import get_latest_kdata_date, get_kdata_schema, default_adjust_type
Expand All @@ -16,6 +16,7 @@
from zvt.domain import StockNews
from zvt.informer import EmailInformer
from zvt.utils import date_time_by_interval
from zvt.informer.inform_utils import add_to_eastmoney

logger = logging.getLogger("__name__")

Expand Down
2 changes: 1 addition & 1 deletion examples/reports/report_core_compay.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from examples.factors.fundamental_selector import FundamentalSelector
from examples.reports import get_subscriber_emails, stocks_with_info
from examples.utils import add_to_eastmoney
from zvt import init_log, zvt_config
from zvt.contract.api import get_entities
from zvt.domain import Stock
from zvt.factors.target_selector import TargetSelector
from zvt.informer.informer import EmailInformer
from zvt.informer.inform_utils import add_to_eastmoney
from zvt.utils.time_utils import now_pd_timestamp, to_time_str

logger = logging.getLogger(__name__)
Expand Down
1 change: 0 additions & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
zvt >= 0.10.1
apscheduler >= 3.4.0
eastmoneypy == 0.1.5
tabulate>=0.8.8
ta
34 changes: 1 addition & 33 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,15 @@
import json
import logging
import os
import time

import eastmoneypy
import pandas as pd
import requests

from zvt.domain import StockNews, Stock, LimitUpInfo
from zvt.utils import date_time_by_interval, today

logger = logging.getLogger(__name__)


def add_to_eastmoney(codes, group, entity_type="stock", over_write=True):
with requests.Session() as session:
codes = list(set(codes))
if over_write:
try:
eastmoneypy.del_group(group_name=group, session=session)
except:
pass
try:
eastmoneypy.create_group(group_name=group, session=session)
except:
pass

group_id = eastmoneypy.get_group_id(group, session=session)

for code in codes:
eastmoneypy.add_to_group(code=code, entity_type=entity_type, group_id=group_id, session=session)


def clean_groups(keep=None):
if keep is None:
keep = ["自选股", "练气", "重要板块", "主线"]

with requests.Session() as session:
groups = eastmoneypy.get_groups(session=session)
groups_to_clean = [group["gid"] for group in groups if group["gname"] not in keep]
for gid in groups_to_clean:
eastmoneypy.del_group(group_id=gid, session=session)


def get_hot_words_config():
with open(os.path.join(os.path.dirname(__file__), "hot.json")) as f:
return json.load(f)
Expand Down
1 change: 0 additions & 1 deletion examples/zhdate/__init__.py

This file was deleted.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ dash_daq==0.5.0
scikit-learn==1.2.1
fastapi==0.110.0
fastapi-pagination==0.12.23
apscheduler==3.10.4
apscheduler==3.10.4
eastmoneypy==0.1.7
22 changes: 10 additions & 12 deletions src/zvt/broker/qmt/qmt_quote.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,21 @@
# -*- coding: utf-8 -*-
import json
import logging
import time
import datetime

import numpy as np
import pandas as pd
from xtquant import xtdata
from sqlalchemy import text
from zvt.api import get_recent_report_date
from zvt.common.query_models import TimeUnit

from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import decode_entity_id, df_to_db, get_db_engine, get_schema_columns, get_db_session
from zvt.contract.api import decode_entity_id, df_to_db, get_db_session
from zvt.domain import StockQuote, Stock
from zvt.utils import (
to_time_str,
current_date,
to_pd_timestamp,
now_time_str,
date_time_by_interval,
pd_is_not_null,
TIME_FORMAT_DAY,
TIME_FORMAT_MINUTE2,
)

from sqlalchemy.dialects.sqlite import insert

# https://dict.thinktrader.net/nativeApi/start_now.html?id=e2M5nZ

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -247,7 +238,14 @@ def download_capital_data():
)


def clear_history_quote():
session = get_db_session("qmt", data_schema=StockQuote)
session.query(StockQuote).filter(StockQuote.timestamp < current_date()).delete()
session.commit()


if __name__ == "__main__":
clear_history_quote()
Stock.record_data(provider="em")
stocks = get_qmt_stocks()
print(stocks)
Expand Down
9 changes: 8 additions & 1 deletion src/zvt/contract/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ def get_by_id(data_schema, id: str, provider: str = None, session: Session = Non
return session.query(data_schema).get(id)


def _row2dict(row):
d = {}
for column in row.__table__.columns:
d[column.name] = getattr(row, column.name)
return d


def get_data(
data_schema: Type[Mixin],
ids: List[str] = None,
Expand Down Expand Up @@ -376,7 +383,7 @@ def get_data(
return query.all()
elif return_type == "dict":
domains = query.all()
return [item.__dict__ for item in domains]
return [_row2dict(item) for item in domains]
elif return_type == "select":
return query.selectable

Expand Down
6 changes: 3 additions & 3 deletions src/zvt/fill_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def gen_kdata_schemas():
# zip_dir(ZVT_TEST_DATA_PATH, zip_file_name=DATA_SAMPLE_ZIP_PATH)
# gen_exports("contract", export_modules=["schema"])
# gen_exports("ml")
# gen_exports("utils")
gen_exports("utils", export_var=True)
# gen_exports('informer')
# gen_exports("api")
# gen_exports('trader')
# gen_exports('autocode')
# gen_exports("ml")
gen_kdata_schemas()
gen_exports("zhdate")
# gen_kdata_schemas()
# gen_exports("recorders")
# gen_exports("tag")
54 changes: 54 additions & 0 deletions src/zvt/informer/inform_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import eastmoneypy

from zvt import zvt_config
from zvt.contract.api import get_entities
from zvt.informer import EmailInformer

import requests


def inform_email(entity_ids, entity_type, target_date, title, provider):
msg = "no targets"
if entity_ids:
entities = get_entities(provider=provider, entity_type=entity_type, entity_ids=entity_ids, return_type="domain")
assert len(entities) == len(entity_ids)

infos = [f"{entity.name}({entity.code})" for entity in entities]
msg = "\n".join(infos) + "\n"

EmailInformer().send_message(zvt_config["email_username"], f"{target_date} {title}", msg)


def add_to_eastmoney(codes, group, entity_type="stock", over_write=True):
with requests.Session() as session:
codes = list(set(codes))
if over_write:
try:
eastmoneypy.del_group(group_name=group, session=session)
except:
pass
try:
eastmoneypy.create_group(group_name=group, session=session)
except:
pass

group_id = eastmoneypy.get_group_id(group, session=session)

for code in codes:
eastmoneypy.add_to_group(code=code, entity_type=entity_type, group_id=group_id, session=session)


def clean_groups(keep):
if keep is None:
keep = ["自选股", "练气", "重要板块", "主线"]

with requests.Session() as session:
groups = eastmoneypy.get_groups(session=session)
groups_to_clean = [group["gid"] for group in groups if group["gname"] not in keep]
for gid in groups_to_clean:
eastmoneypy.del_group(group_id=gid, session=session)


# the __all__ is generated
__all__ = ["inform_email", "add_to_eastmoney", "clean_groups"]
6 changes: 6 additions & 0 deletions src/zvt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
# common code of the package
# export interface in __all__ which contains __all__ of its sub modules

# import all from submodule model_utils
from .model_utils import *
from .model_utils import __all__ as _model_utils_all

__all__ += _model_utils_all

# import all from submodule str_utils
from .str_utils import *
from .str_utils import __all__ as _str_utils_all
Expand Down
16 changes: 0 additions & 16 deletions src/zvt/utils/inform_utils.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/zvt/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ def update_model(db_model, schema):
for key, value in schema.dict().items():
if value is not None:
setattr(db_model, key, value)


# the __all__ is generated
__all__ = ["update_model"]
4 changes: 3 additions & 1 deletion src/zvt/utils/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,17 @@ def count_interval(start_date, end_date):
"TIME_FORMAT_DAY",
"TIME_FORMAT_DAY1",
"TIME_FORMAT_MINUTE",
"TIME_FORMAT_SECOND",
"TIME_FORMAT_MINUTE1",
"TIME_FORMAT_MINUTE2",
"to_pd_timestamp",
"get_local_timezone",
"to_timestamp",
"now_timestamp",
"now_pd_timestamp",
"today",
"current_date",
"tomorrow_date",
"to_time_str",
"now_time_str",
"date_time_by_interval",
Expand All @@ -306,5 +309,4 @@ def count_interval(start_date, end_date):
"is_in_same_interval",
"split_time_interval",
"count_interval",
"tomorrow_date",
]
3 changes: 3 additions & 0 deletions src/zvt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def fill_dict(src, dst):
"iterate_with_step",
"url_unquote",
"parse_url_params",
"set_one_and_only_one",
"flatten_list",
"to_str",
"compare_dicts",
"fill_dict",
]
25 changes: 25 additions & 0 deletions src/zvt/zhdate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# the __all__ is generated
__all__ = []

# __init__.py structure:
# common code of the package
# export interface in __all__ which contains __all__ of its sub modules

# import all from submodule constants
from .constants import *
from .constants import __all__ as _constants_all

__all__ += _constants_all

# import all from submodule ztime
from .ztime import *
from .ztime import __all__ as _ztime_all

__all__ += _ztime_all

# import all from submodule zhdate
from .zhdate import *
from .zhdate import __all__ as _zhdate_all

__all__ += _zhdate_all
2 changes: 2 additions & 0 deletions examples/zhdate/constants.py → src/zvt/zhdate/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,5 @@
"""
从1900年,至2100年每年的农历春节的公历日期
"""
# the __all__ is generated
__all__ = []
6 changes: 5 additions & 1 deletion examples/zhdate/zhdate.py → src/zvt/zhdate/zhdate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, timedelta
from itertools import accumulate

from examples.zhdate.constants import CHINESEYEARCODE, CHINESENEWYEAR
from zvt.zhdate.constants import CHINESEYEARCODE, CHINESENEWYEAR


class ZhDate:
Expand Down Expand Up @@ -255,3 +255,7 @@ def month_days(year):
[int] -- 农历年份所对应的农历月份天数列表
"""
return ZhDate.decode(CHINESEYEARCODE[year - 1900])


# the __all__ is generated
__all__ = []
5 changes: 3 additions & 2 deletions examples/ztime.py → src/zvt/zhdate/ztime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from examples.zhdate.zhdate import ZhDate

from zvt.utils import to_pd_timestamp, current_date, count_interval
from zvt.zhdate.zhdate import ZhDate


def holiday_distance(timestamp=None, days=15):
Expand Down Expand Up @@ -65,3 +64,5 @@ def holiday_distance(timestamp=None, days=15):
# for month in range(1, 13):
# holiday_distance(f"2023-{month}-15")
# holiday_distance(f"2023-{month}-20")
# the __all__ is generated
__all__ = ["holiday_distance"]

0 comments on commit ec1a373

Please sign in to comment.