Skip to content

Commit

Permalink
refactor: use mypy type checking (#218)
Browse files Browse the repository at this point in the history
* add proper type hints
* add mypy type checking to CI
  • Loading branch information
mbelak-dtml authored Mar 14, 2024
1 parent e2889d4 commit 5f07d78
Show file tree
Hide file tree
Showing 17 changed files with 103 additions and 61 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
poetry run python -m pytest -n auto --disable-warnings --cov=edvart tests/
- name: Lint
run: |
poetry run mypy edvart/
poetry run ruff check .
poetry run ruff format --check .
Expand Down
6 changes: 5 additions & 1 deletion edvart/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def embed_image_base64(image_path: str, mime: str = "image/png") -> str:
# Look up directory where currently executed template is located
# Jinja's @environmentfilter or @contextfilter does not seem to provide
# any information about the path of the template.
template_dir = os.path.dirname(inspect.getfile(inspect.currentframe().f_back))
current_frame = inspect.currentframe()
assert current_frame is not None
frame_back = current_frame.f_back
assert frame_back is not None
template_dir = os.path.dirname(inspect.getfile(frame_back))
with open(os.path.join(template_dir, image_path), "rb") as img:
return f"data:{mime};base64," + str(base64.b64encode(img.read()).decode("utf-8"))
2 changes: 1 addition & 1 deletion edvart/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _scatter_plot_2d_noninteractive(
color_categorical = pd.Categorical(df[color_col])
color_codes = color_categorical.codes
else:
color_codes = df[color_col]
color_codes = df[color_col].values.astype(np.signedinteger)
scatter = ax.scatter(x, y, c=color_codes, alpha=opacity)

if is_color_categorical:
Expand Down
21 changes: 12 additions & 9 deletions edvart/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import sys
import warnings
from abc import ABC
from collections.abc import Sized
from copy import copy
from enum import auto
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import isort
import nbconvert
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
self.df = dataframe
self.sections: list[Section] = []
self.verbosity = Verbosity(verbosity)
self._table_of_contents = None
self._table_of_contents: Optional[TableOfContents] = None

def _warn_if_empty(self) -> None:
"""Warns if the report contains no sections."""
Expand Down Expand Up @@ -132,7 +133,7 @@ def _export_data(
return (
code_dedent(
f"""
df_parquet = BytesIO(base64.b85decode({buffer}.decode()))
df_parquet = BytesIO(base64.b85decode({buffer!r}.decode()))
df = pd.read_parquet(df_parquet)"""
),
["import base64", "import pandas as pd", "from io import BytesIO"],
Expand All @@ -143,7 +144,9 @@ def export_notebook(
notebook_filepath: Union[str, os.PathLike],
dataset_name: str = "[INSERT DATASET NAME]",
dataset_description: str = "[INSERT DATASET DESCRIPTION]",
export_data_mode: ExportDataMode = ExportDataMode.NONE,
# mypy assumes that the type of `ExportDataMode.NONE`` is `auto` instead of `ExportDataMode`
# since `auto()` is assigned to it in the enum
export_data_mode: ExportDataMode = ExportDataMode.NONE, # type: ignore
) -> None:
"""Exports the report as an .ipynb file.
Expand Down Expand Up @@ -280,7 +283,7 @@ def _export_html(
Maximum number of seconds to wait for a cell to finish execution.
"""
# Execute notebook to produce output of cells
html_exp_kwargs = dict(
html_exp_kwargs: Dict[str, Any] = dict(
preprocessors=[nbconvert.preprocessors.ExecutePreprocessor(timeout=timeout)]
)
if template_name is not None:
Expand Down Expand Up @@ -348,7 +351,7 @@ def export_html(
# and unpickles the the whole report object from the decoded binary data
unpickle_report = code_dedent(
f"""
data = {buffer_base64}
data = {buffer_base64!r}
report = pickle.loads(base64.b85decode(data), fix_imports=False)
"""
)
Expand Down Expand Up @@ -750,7 +753,7 @@ def __init__(
columns_bivariate_analysis: Optional[List[str]] = None,
columns_multivariate_analysis: Optional[List[str]] = None,
columns_group_analysis: Optional[List[str]] = None,
groupby: Union[str, List[str]] = None,
groupby: Optional[Union[str, List[str]]] = None,
):
super().__init__(dataframe, verbosity)

Expand All @@ -773,7 +776,7 @@ def __init__(
)
if isinstance(groupby, str):
color_col = groupby
elif hasattr(groupby, "__len__") and len(groupby) == 1:
elif isinstance(groupby, Sized) and len(groupby) == 1:
color_col = groupby[0]
else:
color_col = None
Expand Down Expand Up @@ -814,7 +817,7 @@ def __init__(
verbosity: Verbosity = Verbosity.LOW,
):
super().__init__(dataframe, verbosity)
if not is_date(dataframe.index):
if not is_date(dataframe.index.to_series()):
raise ValueError(
"Input dataframe needs to be indexed by time."
"Please reindex your data to be indexed by either a DatetimeIndex or a PeriodIndex."
Expand Down
8 changes: 7 additions & 1 deletion edvart/report_sections/bivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __init__(
raise ValueError("Either both or neither of columns_x, columns_y must be specified.")
# For analyses which do not take columns_pairs, prepare columns_x and columns_y in case
# columns_pairs is the only parameter specified
columns_x_no_pairs: Optional[List[str]]
columns_y_no_pairs: Optional[List[str]]
if columns is None and columns_x is None and columns_pairs is not None:
columns_x_no_pairs = [pair[0] for pair in columns_pairs]
columns_y_no_pairs = [pair[1] for pair in columns_pairs]
Expand Down Expand Up @@ -456,6 +458,7 @@ def _get_columns_x_y(
if columns is None:
columns = list(df.columns)
columns_x = columns_y = columns
assert columns_y is not None
columns_x = [col for col in columns_x if is_numeric(df[col])]
columns_y = [col for col in columns_y if is_numeric(df[col])]

Expand Down Expand Up @@ -722,6 +725,7 @@ def include_column(col: str) -> bool:
columns_x = columns
columns_y = columns
if not allow_categorical:
assert columns_y is not None
columns_x = list(filter(include_column, columns_x))
columns_y = list(filter(include_column, columns_y))
sns.pairplot(df, x_vars=columns_x, y_vars=columns_y, hue=color_col)
Expand Down Expand Up @@ -908,6 +912,8 @@ def include_column(col: str) -> bool:
if columns_x is None:
columns_pairs = list(itertools.combinations(columns, 2))
else:
assert columns_x is not None
assert columns_y is not None
columns_pairs = [
(col_x, col_y)
for (col_x, col_y) in itertools.product(columns_x, columns_y)
Expand Down Expand Up @@ -971,7 +977,7 @@ def contingency_table(
annot = table.replace(0, "") if hide_zeros else table

ax = sns.heatmap(
scaling_func(table),
scaling_func(table.values),
annot=annot,
fmt="",
cbar=False,
Expand Down
11 changes: 7 additions & 4 deletions edvart/report_sections/dataset_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,11 @@ def data_types(df: pd.DataFrame, columns: Optional[List[str]] = None) -> None:
"""
if columns is not None:
df = df[columns]
dtypes = df.apply(
func=lambda x_: str(infer_data_type(x_)),

# Type ignored because the apply is not properly typed: the type hints for
# the parameter `func` do not cover the complete set of possible inputs.
dtypes: pd.Series[str] = df.apply(
func=lambda x_: str(infer_data_type(x_)), # type: ignore
axis=0,
result_type="expand",
)
Expand Down Expand Up @@ -652,7 +655,7 @@ def missing_values(
bar_plot_title: str = "Missing Values Percentage of Each Column",
bar_plot_ylim: float = 0,
bar_plot_color: str = "#FFA07A",
**bar_plot_args: Dict[str, Any],
**bar_plot_args: Any,
) -> None:
"""Displays a table of missing values percentages for each column of df and a bar plot
of the percentages.
Expand All @@ -675,7 +678,7 @@ def missing_values(
Bar plot y axis bottom limit.
bar_plot_color : str
Color of bars in the bar plot in hex format.
bar_plot_args : Dict[str, Any]
bar_plot_args : Any
Additional kwargs passed to pandas.Series.bar.
"""
if columns is not None:
Expand Down
21 changes: 13 additions & 8 deletions edvart/report_sections/group_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Union
from collections.abc import Hashable
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import colorlover as cl
import nbformat.v4 as nbfv4
Expand Down Expand Up @@ -102,7 +103,7 @@ def required_imports(self) -> List[str]:
"import plotly.graph_objects as go",
"from edvart.data_types import infer_data_type, DataType",
"from edvart import utils",
"from typing import List, Dict, Optional, Callable",
"from typing import List, Dict, Optional, Callable, Iterable",
"from plotly.subplots import make_subplots",
]

Expand Down Expand Up @@ -218,7 +219,7 @@ def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None:
)
cells.append(nbfv4.new_code_cell(code))

columns = self.columns if self.columns is not None else df.columns
columns = self.columns if self.columns is not None else df.columns.to_list()

if not self.show_statistics and not self.show_dist:
return
Expand Down Expand Up @@ -362,7 +363,7 @@ def within_group_stats(
df: pd.DataFrame,
groupby: List[str],
column: str,
stats: Dict[str, Callable[[pd.Series], float]] = None,
stats: Optional[Dict[str, Callable[[pd.Series], float]]] = None,
round_decimals: int = 2,
) -> None:
"""Display withing group statistics for a column of df grouped by one or other more columns.
Expand Down Expand Up @@ -448,7 +449,9 @@ def group_missing_values(
df_grouped = df.groupby(groupby)[columns]

# Calculate number of samples in each group
sizes = df_grouped.size().rename("Group Size")
sizes = df_grouped.size()
assert isinstance(sizes, pd.Series)
sizes = sizes.rename("Group Size")

# Calculate missing values percentage of each column for each group
missing = df_grouped.apply(lambda g: g.isna().sum(axis=0))
Expand Down Expand Up @@ -490,7 +493,7 @@ def color_cell(value):
background-color: {bg_hex};
"""

render = final_table.style.applymap(
render = final_table.style.map(
func=color_cell, subset=pd.IndexSlice[:, colored_columns]
).format(formatter="{0:.2f} %", subset=pd.IndexSlice[:, colored_columns])
else:
Expand Down Expand Up @@ -553,7 +556,8 @@ def group_barplot(

fig = go.Figure()
for color_idx, (idx, row) in enumerate(pivot.iterrows()):
if hasattr(idx, "__len__") and not isinstance(idx, str):
group_name: Hashable
if isinstance(idx, Iterable) and not isinstance(idx, str):
group_name = "_".join([str(i) for i in idx])
else:
group_name = idx
Expand Down Expand Up @@ -641,7 +645,8 @@ def overlaid_histograms(
)

for color_idx, (name, group) in enumerate(df.groupby(groupby)):
if hasattr(name, "__len__") and not isinstance(name, str):
group_name: Hashable
if isinstance(name, Iterable) and not isinstance(name, str):
group_name = "_".join([str(i) for i in name])
else:
group_name = name
Expand Down
20 changes: 11 additions & 9 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import nbformat.v4 as nbfv4
Expand Down Expand Up @@ -487,7 +487,7 @@ def pca_explained_variance(
plt.ylabel("Explained variance ratio")
plt.xticks(
ticks=range(len(pca.explained_variance_ratio_)),
labels=range(1, (len(pca.explained_variance_ratio_) + 1)),
labels=[str(label) for label in range(1, (len(pca.explained_variance_ratio_) + 1))],
)
if show_grid:
plt.grid()
Expand Down Expand Up @@ -630,13 +630,15 @@ def parallel_coordinates(
columns = [col for col in columns if col not in hide_columns]
if drop_na:
df = df.dropna()

line: Optional[Dict[str, Any]] = None
if color_col is not None:
is_categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)

colorscale: Union[List[Tuple[float, str]], str]
if is_categorical_color:
categories = df[color_col].unique()
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
Expand Down Expand Up @@ -669,8 +671,6 @@ def parallel_coordinates(
"cmax": len(categories) - 0.5,
}
)
else:
line = None
# Add numeric columns to dimensions
dimensions = [{"label": col_name, "values": df[col_name]} for col_name in numeric_columns]
# Add categorical columns to dimensions
Expand Down Expand Up @@ -818,12 +818,15 @@ def parallel_categories(
columns = [col for col in columns if col not in hide_columns]
if drop_na:
df = df.dropna()

line: Optional[Dict[str, Any]] = None
if color_col is not None:
categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)
colorscale: Union[List[Tuple[float, str]], str]
if categorical_color:
categories = df[color_col].unique()
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
Expand All @@ -833,14 +836,15 @@ def parallel_categories(
color_series = df[color_col]
colorscale = "Bluered_r"

colorbar: Dict[str, Any] = {"title": color_col}
line = {
"color": color_series,
"colorscale": colorscale,
"colorbar": {"title": color_col},
"colorbar": colorbar,
}

if categorical_color:
line["colorbar"].update(
colorbar.update(
{
"tickvals": color_series.unique(),
"ticktext": categories,
Expand All @@ -855,8 +859,6 @@ def parallel_categories(
"cmax": len(categories) - 0.5,
}
)
else:
line = None

dimensions = [go.parcats.Dimension(values=df[col_name], label=col_name) for col_name in columns]

Expand Down
2 changes: 1 addition & 1 deletion edvart/report_sections/table_of_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def show(self, sections: List[Section]) -> None:
"""
display(Markdown(self._title))

lines = []
lines: List[str] = []
for section in sections:
self._add_section_lines(section, 1, lines, self._include_subsections)
display(Markdown("\n".join(lines)))
Loading

0 comments on commit 5f07d78

Please sign in to comment.