Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some type hints #443

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pybaseball/amateur_draft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import pandas as pd

from . import cache
Expand All @@ -9,7 +11,7 @@
_URL = "https://www.baseball-reference.com/draft/?year_ID={year}&draft_round={draft_round}&draft_type=junreg&query_type=year_round&"


def get_draft_results(year: int, draft_round: int) -> pd.DataFrame:
def get_draft_results(year: int, draft_round: int) -> List[pd.DataFrame]:
url = _URL.format(year=year, draft_round=draft_round)
res = session.get(url, timeout=None).content
draft_results = pd.read_html(res)
Expand All @@ -23,9 +25,9 @@ def amateur_draft(year: int, draft_round: int, keep_stats: bool = True) -> pd.Da

ARGUMENTS
year: The year for which you wish to retrieve draft results.
draft_round: The round for which you wish to retrieve draft results. There is no distinction made
draft_round: The round for which you wish to retrieve draft results. There is no distinction made
between the competitive balance, supplementary, and main portions of a round.
keep_stats: A boolean parameter that controls whether the major league stats of each draftee is
keep_stats: A boolean parameter that controls whether the major league stats of each draftee is
displayed. Default set to true.
"""
draft_results = get_draft_results(year, draft_round)
Expand Down
2 changes: 1 addition & 1 deletion pybaseball/amateur_draft_by_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def amateur_draft_by_team(
Get amateur draft results by team and year.

ARGUMENTS
team: Team code which you want to check. See docs for team codes
team: Team code which you want to check. See docs for team codes
(https://github.com/jldbc/pybaseball/blob/master/docs/amateur_draft_by_team.md)
year: Year which you want to check.

Expand Down
4 changes: 2 additions & 2 deletions pybaseball/analysis/projections/marcels/age_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def age_adjustment(age: Optional[float]) -> float:
"""
if isnull(age):
return float("nan")

assert age

if age <= 0:
return 1
elif age >= 29:
Expand Down
2 changes: 1 addition & 1 deletion pybaseball/analysis/projections/marcels/marcels_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def metric_projection(self, metric_name: str, projected_season: int) -> pd.DataF
.loc[:, [metric_name]]
)

def projections(self, projected_season: int, computed_metrics: List[str] = None) -> pd.DataFrame:
def projections(self, projected_season: int, computed_metrics: Optional[List[str]] = None) -> pd.DataFrame:
"""
returns projections for all metrics in `computed_metrics`. If
`computed_metrics` is None it uses the default set.
Expand Down
7 changes: 1 addition & 6 deletions pybaseball/batting_leaders.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import warnings
from typing import Optional

import pandas as pd

from .datasources.fangraphs import fg_batting_data


# This is just a pass through for the new, more configurable function
batting_stats = fg_batting_data
batting_stats = fg_batting_data
2 changes: 1 addition & 1 deletion pybaseball/cache/cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CacheConfig(singleton.Singleton):
CFG_FILENAME = 'cache_config.json'
PYBASEBALL_CACHE_ENV = 'PYBASEBALL_CACHE'

def __init__(self, enabled: bool = False, default_expiration: int = None, cache_type: Optional[str] = None):
def __init__(self, enabled: bool = False, default_expiration: Optional[int] = None, cache_type: Optional[str] = None):
self.enabled = enabled
self.cache_directory = os.environ.get(CacheConfig.PYBASEBALL_CACHE_ENV) or CacheConfig.DEFAULT_CACHE_DIR
self.default_expiration = default_expiration or CacheConfig.DEFAULT_EXPIRATION
Expand Down
2 changes: 1 addition & 1 deletion pybaseball/cache/cache_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class CacheRecord:
def __init__(self, filename: str = None, data: Optional[Dict[str, Any]] = None,
def __init__(self, filename: Optional[str] = None, data: Optional[Dict[str, Any]] = None,
expires: DateOrNumDays = cache_config.CacheConfig.DEFAULT_EXPIRATION):
''' Create a new cache record. Loads from file if filename is provided, otherwise creates from data, expires '''

Expand Down
6 changes: 3 additions & 3 deletions pybaseball/datahelpers/column_mapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections import Counter
from typing import Callable, Iterator, List, Optional
from typing import Callable, Iterator, List, Optional, Counter as TypingCounter

ColumnListMapperFunction = Callable[[List[str]], Iterator[str]]

class GenericColumnMapper:
def __init__(self):
self.call_counts = Counter()
def __init__(self) -> None:
self.call_counts: TypingCounter[str] = Counter()

def _short_circuit(self, column_name: str) -> Optional[str]:
return None
Expand Down
2 changes: 1 addition & 1 deletion pybaseball/datahelpers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def __new__(cls: Type['Singleton'], *args: Any, **kwargs: Any) -> 'Singleton':
cls.__INSTANCE__ = super(Singleton, cls).__new__(cls)

assert cls.__INSTANCE__ is not None
return cls.__INSTANCE__
return cls.__INSTANCE__
3 changes: 1 addition & 2 deletions pybaseball/datasources/bref.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, max_requests_per_minute: int = 10) -> None:
self.max_requests_per_minute = max_requests_per_minute
self.last_request: Optional[datetime.datetime] = None
self.session = requests.Session()

def get(self, url: str, **kwargs: Any) -> requests.Response:
if self.last_request:
delta = datetime.datetime.now() - self.last_request
Expand All @@ -32,4 +32,3 @@ def get(self, url: str, **kwargs: Any) -> requests.Response:
self.last_request = datetime.datetime.now()

return self.session.get(url, **kwargs)

6 changes: 3 additions & 3 deletions pybaseball/datasources/fangraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class FangraphsBattingStatsTable(FangraphsDataTable):
ROW_ID_NAME = 'IDfg'

@cache.df_cache()
def fetch(self, *args, **kwargs):
def fetch(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
return super().fetch(*args, **kwargs)

def _postprocess(self, data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -186,7 +186,7 @@ class FangraphsFieldingStatsTable(FangraphsDataTable):
ROW_ID_NAME = 'IDfg'

@cache.df_cache()
def fetch(self, *args, **kwargs):
def fetch(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
return super().fetch(*args, **kwargs)

def _postprocess(self, data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -199,7 +199,7 @@ class FangraphsPitchingStatsTable(FangraphsDataTable):
ROW_ID_NAME = 'IDfg'

@cache.df_cache()
def fetch(self, *args, **kwargs):
def fetch(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
return super().fetch(*args, **kwargs)

def _postprocess(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down
12 changes: 6 additions & 6 deletions pybaseball/datasources/html_table_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class HTMLTableProcessor:
def __init__(self, root_url: str, headings_xpath: str, data_rows_xpath: str, data_cell_xpath: str,
table_class: str = None):
table_class: Optional[str] = None):
self.root_url = root_url
self.table_class = table_class
self.headings_xpath = headings_xpath.format(TABLE_XPATH=self.table_xpath)
Expand All @@ -25,7 +25,7 @@ def table_xpath(self) -> str:
return f'//table[@class="{self.table_class}"]'
return '//table'

def get_tabular_data_from_element(self, element: lxml.etree.Element, column_name_mapper: ColumnListMapperFunction = None,
def get_tabular_data_from_element(self, element: lxml.etree.Element, column_name_mapper: Optional[ColumnListMapperFunction] = None,
known_percentages: Optional[List[str]] = None, row_id_func: RowIdFunction = None,
row_id_name: Optional[str] = None) -> pd.DataFrame:
headings = element.xpath(self.headings_xpath)
Expand All @@ -51,7 +51,7 @@ def get_tabular_data_from_element(self, element: lxml.etree.Element, column_name

return fg_data

def get_tabular_data_from_html(self, html: Union[str, bytes], column_name_mapper: ColumnListMapperFunction = None,
def get_tabular_data_from_html(self, html: Union[str, bytes], column_name_mapper: Optional[ColumnListMapperFunction] = None,
known_percentages: Optional[List[str]] = None, row_id_func: RowIdFunction = None,
row_id_name: Optional[str] = None) -> pd.DataFrame:
html_dom = lxml.etree.HTML(html)
Expand All @@ -64,8 +64,8 @@ def get_tabular_data_from_html(self, html: Union[str, bytes], column_name_mapper
row_id_name=row_id_name,
)

def get_tabular_data_from_url(self, url: str, query_params: Dict[str, Union[str, int]] = None,
column_name_mapper: ColumnListMapperFunction = None,
def get_tabular_data_from_url(self, url: str, query_params: Optional[Dict[str, Union[str, int]]] = None,
column_name_mapper: Optional[ColumnListMapperFunction] = None,
known_percentages: Optional[List[str]] = None, row_id_func: RowIdFunction = None,
row_id_name: Optional[str] = None) -> pd.DataFrame:
response = requests.get(self.root_url + url, params=query_params)
Expand All @@ -84,7 +84,7 @@ def get_tabular_data_from_url(self, url: str, query_params: Dict[str, Union[str,
)

def get_tabular_data_from_options(self, base_url: str, query_params: Dict[str, Union[str, int]],
column_name_mapper: ColumnListMapperFunction = None,
column_name_mapper: Optional[ColumnListMapperFunction] = None,
known_percentages: Optional[List[str]] = None, row_id_func: RowIdFunction = None,
row_id_name: Optional[str] = None) -> pd.DataFrame:
return self.get_tabular_data_from_url(
Expand Down
1 change: 0 additions & 1 deletion pybaseball/datasources/statcast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import os
from datetime import datetime
from typing import List, Union

Expand Down
2 changes: 1 addition & 1 deletion pybaseball/enums/fangraphs/league.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class FangraphsLeague(EnumBase):
EWL = 'ewl' # East-West League
NSL = 'nsl' # Negro Southern League
NN2 = 'nn2' # Negro National League II
NAL = "nal" # Negro American League
NAL = "nal" # Negro American League
2 changes: 1 addition & 1 deletion pybaseball/lahman.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_lahman_zip() -> Optional[ZipFile]:
_handle = ZipFile(BytesIO(s.content))
return _handle

def download_lahman():
def download_lahman() -> None:
# download entire lahman db to present working directory
z = get_lahman_zip()
if z is not None:
Expand Down
5 changes: 0 additions & 5 deletions pybaseball/pitching_leaders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import warnings
from typing import Optional

import pandas as pd

from .datasources.fangraphs import fg_pitching_data


Expand Down
18 changes: 9 additions & 9 deletions pybaseball/playerid_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import zipfile

from typing import List, Tuple, Iterable
from typing import List, Tuple, Iterable, Optional

import pandas as pd
import requests
Expand All @@ -18,7 +18,7 @@
_client = None


def get_register_file():
def get_register_file() -> str:
return os.path.join(cache.config.cache_directory, 'chadwick-register.csv')


Expand Down Expand Up @@ -72,15 +72,15 @@ def chadwick_register(save: bool = False) -> pd.DataFrame:
return table


def get_lookup_table(save=False):
def get_lookup_table(save: bool = False) -> pd.DataFrame:
table = chadwick_register(save)
# make these lowercase to avoid capitalization mistakes when searching
table['name_last'] = table['name_last'].str.lower()
table['name_first'] = table['name_first'].str.lower()
return table


def get_closest_names(last: str, first: str, player_table: pd.DataFrame) -> pd.DataFrame:
def get_closest_names(last: str, first: Optional[str], player_table: pd.DataFrame) -> pd.DataFrame:
"""Calculates similarity of first and last name provided with all players in player_table

Args:
Expand All @@ -102,7 +102,7 @@ class _PlayerSearchClient:
def __init__(self) -> None:
self.table = get_lookup_table()

def search(self, last: str, first: str = None, fuzzy: bool = False, ignore_accents: bool = False) -> pd.DataFrame:
def search(self, last: str, first: Optional[str] = None, fuzzy: bool = False, ignore_accents: bool = False) -> pd.DataFrame:
"""Lookup playerIDs (MLB AM, bbref, retrosheet, FG) for a given player

Args:
Expand Down Expand Up @@ -137,7 +137,7 @@ def search(self, last: str, first: str = None, fuzzy: bool = False, ignore_accen
if len(results) == 0 and fuzzy:
print("No identically matched names found! Returning the 5 most similar names.")
results=get_closest_names(last=last, first=first, player_table=self.table)

return results


Expand All @@ -150,12 +150,12 @@ def search_list(self, player_list: List[Tuple[str, str]]) -> pd.DataFrame:

Returns:
pd.DataFrame: DataFrame of playerIDs, name, years played
'''
'''
results = pd.DataFrame()

for last, first in player_list:
results = results.append(self.search(last, first), ignore_index=True)

return results


Expand Down Expand Up @@ -217,7 +217,7 @@ def player_search_list(player_list: List[Tuple[str, str]]) -> pd.DataFrame:

Returns:
pd.DataFrame: DataFrame of playerIDs, name, years played
'''
'''
client = _get_client()
return client.search_list(player_list)

Expand Down
Loading