Skip to content

Commit

Permalink
feat: add type schema (#1274)
Browse files Browse the repository at this point in the history
* feat: allows user to define variable types
  • Loading branch information
alexbarros authored Mar 3, 2023
1 parent b9ada64 commit 79856bc
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/ydata_profiling/compare_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def _compare_profile_report_preprocess(
config.html.style.primary_colors
)

# enforce same types
for report in reports[1:]:
report._typeset = reports[0].typeset

# Obtain description sets
descriptions = [report.get_description() for report in reports]
for label, description in zip(labels, descriptions):
Expand Down
11 changes: 9 additions & 2 deletions src/ydata_profiling/model/pandas/summary_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ydata_profiling.config import Settings
from ydata_profiling.model.summarizer import BaseSummarizer
from ydata_profiling.model.summary import describe_1d, get_series_descriptions
from ydata_profiling.model.typeset import ProfilingTypeSet
from ydata_profiling.utils.dataframe import sort_column_names


Expand All @@ -37,8 +38,13 @@ def pandas_describe_1d(
# Make sure pd.NA is not in the series
series = series.fillna(np.nan)

# get `infer_dtypes` (bool) from config
if config.infer_dtypes:
if (
isinstance(typeset, ProfilingTypeSet)
and typeset.type_schema
and series.name in typeset.type_schema
):
vtype = typeset.type_schema[series.name]
elif config.infer_dtypes:
# Infer variable types
vtype = typeset.infer_type(series)
series = typeset.cast_to_inferred(series)
Expand All @@ -47,6 +53,7 @@ def pandas_describe_1d(
# [new dtypes, changed using `astype` function are now considered]
vtype = typeset.detect_type(series)

typeset.type_schema[series.name] = vtype
return summarizer.summarize(config, series, dtype=vtype)


Expand Down
13 changes: 12 additions & 1 deletion src/ydata_profiling/model/typeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,22 @@ def is_timedependent(series: pd.Series) -> bool:


class ProfilingTypeSet(visions.VisionsTypeset):
def __init__(self, config: Settings):
def __init__(self, config: Settings, type_schema: dict = None):
self.config = config

types = typeset_types(config)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
super().__init__(types)

self.type_schema = self._init_type_schema(type_schema or {})

def _init_type_schema(self, type_schema: dict) -> dict:
return {k: self._get_type(v) for k, v in type_schema.items()}

def _get_type(self, type_name: str) -> visions.VisionsBaseType:
for t in self.types:
if t.__name__.lower() == type_name.lower():
return t
raise ValueError(f"Type [{type_name}] not found.")
5 changes: 4 additions & 1 deletion src/ydata_profiling/profile_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
typeset: Optional[VisionsTypeset] = None,
summarizer: Optional[BaseSummarizer] = None,
config: Optional[Settings] = None,
type_schema: Optional[dict] = None,
**kwargs,
):
"""Generate a ProfileReport based on a pandas or spark.sql DataFrame
Expand All @@ -89,6 +90,7 @@ def __init__(
sample: optional dict(name="Sample title", caption="Caption", data=pd.DataFrame())
typeset: optional user typeset to use for type inference
summarizer: optional user summarizer to generate custom summary output
type_schema: optional dict containing pairs of `column name`: `type`
**kwargs: other arguments, for valid arguments, check the default configuration file.
"""
self.__validate_inputs(df, minimal, tsmode, config_file, lazy)
Expand Down Expand Up @@ -139,6 +141,7 @@ def __init__(
self.config = report_config
self._df_hash = None
self._sample = sample
self._type_schema = type_schema
self._typeset = typeset
self._summarizer = summarizer

Expand Down Expand Up @@ -230,7 +233,7 @@ def invalidate_cache(self, subset: Optional[str] = None) -> None:
@property
def typeset(self) -> Optional[VisionsTypeset]:
if self._typeset is None:
self._typeset = ProfilingTypeSet(self.config)
self._typeset = ProfilingTypeSet(self.config, self._type_schema)
return self._typeset

@property
Expand Down
37 changes: 36 additions & 1 deletion tests/unit/test_typeset_default.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import numpy as np
import pandas as pd
import pytest
from visions.test.series import get_series
from visions.test.utils import (
Expand All @@ -14,6 +16,7 @@
from tests.unit.test_utils import patch_arg
from ydata_profiling.config import Settings
from ydata_profiling.model.typeset import ProfilingTypeSet
from ydata_profiling.profile_report import ProfileReport

base_path = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -161,7 +164,7 @@
)
)
def test_contains(name, series, contains_type, member):
"""Test the generated combinations for "series in type"
"""Test the generated combinations for "series in type".
Args:
series: the series to test
Expand Down Expand Up @@ -349,3 +352,35 @@ def test_conversion(name, source_type, relation_type, series, member):
"""
result, message = convert(name, source_type, relation_type, series, member)
assert result, message


@pytest.fixture
def dataframe(size: int = 1000) -> pd.DataFrame:
return pd.DataFrame(
{
"boolean": np.random.choice([True, False], size=size),
"numeric": np.random.rand(size),
"categorical": np.random.choice(np.arange(5), size=size),
"timeseries": np.arange(size),
}
)


def convertion_map() -> list:
types = {
"boolean": ["Categorical", "Unsupported"],
"numeric": ["Categorical", "Boolean", "Unsupported"],
"categorical": ["Numeric", "Boolean", "TimeSeries", "Unsupported"],
"timeseries": ["Numeric", "Boolean", "Categorical", "Unsupported"],
}
return [(k, {k: i}) for k, v in types.items() for i in v]


@pytest.mark.parametrize("column,type_schema", convertion_map())
def test_type_schema(dataframe: pd.DataFrame, column: str, type_schema: dict):
prof = ProfileReport(dataframe[[column]], tsmode=True, type_schema=type_schema)
prof.get_description()
assert isinstance(prof.typeset, ProfilingTypeSet)
assert prof.typeset.type_schema[column] == prof.typeset._get_type(
type_schema[column]
)

0 comments on commit 79856bc

Please sign in to comment.