Skip to content

Commit

Permalink
feat: treat categorical color column which is also numeric as categorical in parallel coordinates & parallel categories (#191)
Browse files Browse the repository at this point in the history

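Previously, categorical color columns that are also numeric (e.g. with
values `{1,2,3}`, such as
`edvart.example_datasets.dataset_auto()["origin"]`) were treated as
numeric in the parallel coordinates & parallel categories sections of
multivariate analysis, i.e. a continuous color scale was used.
The color column's data type is now inferred with `infer_data_type`, and
columns inferred as categorical, unique, or boolean are treated as
categorical.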
  • Loading branch information
mbelak-dtml authored Nov 2, 2023
1 parent e3a00f9 commit 8fa399e
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from IPython.display import Markdown, display
from sklearn.preprocessing import StandardScaler

from edvart.data_types import is_boolean, is_categorical, is_numeric
from edvart.data_types import DataType, infer_data_type, is_boolean, is_categorical, is_numeric
from edvart.plots import scatter_plot_2d
from edvart.report_sections.code_string_formatting import code_dedent, get_code
from edvart.report_sections.section_base import ReportSection, Section, Verbosity
Expand Down Expand Up @@ -620,7 +620,11 @@ def parallel_coordinates(
if drop_na:
df = df.dropna()
if color_col is not None:
is_categorical_color = not is_numeric(df[color_col]) or is_boolean(df[color_col])
is_categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)

if is_categorical_color:
categories = df[color_col].unique()
Expand Down Expand Up @@ -799,7 +803,11 @@ def parallel_categories(
if drop_na:
df = df.dropna()
if color_col is not None:
categorical_color = not is_numeric(df[color_col]) or is_boolean(df[color_col])
categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)
if categorical_color:
categories = df[color_col].unique()
colorscale = list(discrete_colorscale(len(categories)))
Expand Down

0 comments on commit 8fa399e

Please sign in to comment.