Skip to content

Commit

Permalink
Updated DataVisualizer to handle target_type_identification (#893)
Browse files Browse the repository at this point in the history
Updates the DataVisualizer to perform target type identification as implemented in Manifold. This was an original requirement of the DataVisualizer but remained unimplemented since ParallelCoordinates and RadViz were the only main library subclasses. This is the first step in the ProjectionVisualizer high-dimensional visualization base class.

Related to #874
  • Loading branch information
naresh-bachwani authored and bbengfort committed Jul 2, 2019
1 parent 07e7597 commit 39a4a02
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/test_cluster/test_elbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_distortion_metric(self):
assert len(visualizer.k_scores_) == 4

visualizer.finalize()
self.assert_images_similar(visualizer)
self.assert_images_similar(visualizer, tol=0.03)
assert_array_almost_equal(visualizer.k_scores_, expected)

@pytest.mark.xfail(
Expand Down
116 changes: 114 additions & 2 deletions tests/test_features/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,44 @@
## Imports
##########################################################################

import pytest
from unittest.mock import patch

from yellowbrick.base import Visualizer
from yellowbrick.features.base import FeatureVisualizer
from yellowbrick.features.base import *
from tests.base import VisualTestCase
from ..fixtures import TestDataset

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import make_classification, make_regression

##########################################################################
## Fixtures
##########################################################################

@pytest.fixture(scope='class')
def discrete(request):
"""
Creare a random classification fixture.
"""
X, y = make_classification(
n_samples=400, n_features=12, n_informative=10, n_redundant=0,
n_classes=5, random_state=2019)

# Set a class attribute for digits
request.cls.discrete = TestDataset(X, y)

@pytest.fixture(scope='class')
def continuous(request):
"""
Creates a random regressor fixture.
"""
X, y = make_regression(
n_samples=500, n_features=22, n_informative=8, random_state=2019
)

# Set a class attribute for digits
request.cls.continuous = TestDataset(X, y)


##########################################################################
Expand All @@ -30,11 +63,90 @@

class TestFeatureVisualizerBase(VisualTestCase):

def test_subclass(self):
def test_subclass(self):
"""
Assert the feature visualizer is in its rightful place
"""
visualizer = FeatureVisualizer()
assert isinstance(visualizer, TransformerMixin)
assert isinstance(visualizer, BaseEstimator)
assert isinstance(visualizer, Visualizer)


##########################################################################
## DataVisualizer Tests
##########################################################################

@pytest.mark.usefixtures("discrete", "continuous")
class TestDataVisualizerBase(VisualTestCase):

@patch.object(DataVisualizer, 'draw')
def test_single(self, mock_draw):

dataviz = DataVisualizer()
# Check default is auto
assert dataviz.target_type == TargetType.AUTO

# Assert single when y is None
dataviz._determine_target_color_type(None)
assert dataviz._target_color_type == TargetType.SINGLE

# None overrides specified target
dataviz = DataVisualizer(target_type="continuous")
X, y = self.continuous
dataviz.fit(X)
mock_draw.assert_called_once()
assert dataviz._colors == 'b'
assert dataviz._target_color_type == TargetType.SINGLE

# None overrides specified target
dataviz = DataVisualizer(target_type="discrete")
dataviz._determine_target_color_type(None)
assert dataviz._target_color_type == TargetType.SINGLE

@patch.object(DataVisualizer, 'draw')
def test_continuous(self, mock_draw):
# Check when y is continuous
X, y = self.continuous
dataviz = DataVisualizer()
dataviz.fit(X, y)
mock_draw.assert_called_once()
assert hasattr(dataviz, "range_")
assert dataviz._target_color_type == TargetType.CONTINUOUS

# Check when default is set to continuous and discrete data passed in
dataviz = DataVisualizer(target_type="continuous")
X, y = self.discrete
dataviz._determine_target_color_type(y)
assert dataviz._target_color_type == TargetType.CONTINUOUS

def test_discrete(self):
# Check when y is discrete
_, y = self.discrete
dataviz = DataVisualizer()
dataviz._determine_target_color_type(y)
assert dataviz._target_color_type == TargetType.DISCRETE

# Check when default is set to discrete and continuous data passed in
_, y = self.continuous
dataviz = DataVisualizer(target_type="discrete")
dataviz._determine_target_color_type(y)
assert dataviz._target_color_type == TargetType.DISCRETE

def test_bad_target(self):
# Bad target raises exception
# None overrides specified target
msg = "unknown target color type 'foo'"
with pytest.raises(YellowbrickValueError, match=msg):
DataVisualizer(target_type="foo")

@patch.object(DataVisualizer, 'draw')
def test_classes(self, mock_draw):
# Checks that classes are assigned correctly
X, y = self.discrete
classes = ['a', 'b', 'c', 'd', 'e']
dataviz = DataVisualizer(classes=classes, target_type='discrete')
dataviz.fit(X, y)
assert dataviz.classes_ == classes
assert list(dataviz._colors.keys()) == classes

3 changes: 3 additions & 0 deletions yellowbrick/contrib/missing/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class MissingValuesBar(MissingDataVisualizer):
"""

def __init__(self, width=0.5, color='black', colors=None, classes=None, **kwargs):

if "target_type" not in kwargs:
kwargs["target_type"] = "single"
super(MissingValuesBar, self).__init__(**kwargs)
self.width = width # the width of the bars
self.classes_ = classes
Expand Down
2 changes: 2 additions & 0 deletions yellowbrick/contrib/missing/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class MissingValuesDispersion(MissingDataVisualizer):

def __init__(self, alpha=0.5, marker="|", classes=None, **kwargs):

if "target_type" not in kwargs:
kwargs["target_type"] = "single"
super(MissingValuesDispersion, self).__init__(**kwargs)
self.alpha = alpha
self.marker = marker
Expand Down
94 changes: 89 additions & 5 deletions yellowbrick/features/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
##########################################################################

import numpy as np
import matplotlib as mpl
import warnings
from enum import Enum

from yellowbrick.base import Visualizer
from yellowbrick.utils import is_dataframe
from yellowbrick.style import resolve_colors
from yellowbrick.exceptions import YellowbrickValueError, YellowbrickWarning

from sklearn.base import TransformerMixin


Expand Down Expand Up @@ -120,10 +126,18 @@ def fit(self, X, y=None, **fit_params):

return self


##########################################################################
## Data Visualizers
##########################################################################

class TargetType(Enum):
AUTO = 'auto'
SINGLE = 'single'
DISCRETE = 'discrete'
CONTINUOUS = 'continuous'


class DataVisualizer(MultiFeatureVisualizer):
"""
Data Visualizers are a subclass of Feature Visualizers which plot the
Expand Down Expand Up @@ -166,6 +180,16 @@ class DataVisualizer(MultiFeatureVisualizer):
Use either color to colorize the lines on a per class basis or
colormap to color them on a continuous scale.
target_type : str, default: "auto"
Specify the type of target as either "discrete" (classes) or "continuous"
(real numbers, usually for regression). If "auto", then it will
attempt to determine the type by counting the number of unique values.
If the target is discrete, the colors are returned as a dict with classes
being the keys. If continuous the colors will be list having value of
color for each point. In either case, if no target is specified, then
color will be specified as blue.
kwargs : dict
Keyword arguments that are passed to the base class and may influence
the visualization as defined in other Visualizers.
Expand All @@ -177,7 +201,7 @@ class DataVisualizer(MultiFeatureVisualizer):
"""

def __init__(self, ax=None, features=None, classes=None, color=None,
colormap=None, **kwargs):
colormap=None, target_type="auto", **kwargs):
"""
Initialize the data visualization with many of the options required
in order to make most visualizations work.
Expand All @@ -190,6 +214,11 @@ def __init__(self, ax=None, features=None, classes=None, color=None,
# Visual Parameters
self.color = color
self.colormap = colormap
try:
# Ensures that target is either Single, Discrete, Continuous or Auto
self.target_type = TargetType(target_type)
except ValueError:
raise YellowbrickValueError("unknown target color type '{}'".format(target_type))

def fit(self, X, y=None, **kwargs):
"""
Expand All @@ -215,13 +244,68 @@ def fit(self, X, y=None, **kwargs):
"""
super(DataVisualizer, self).fit(X, y, **kwargs)

# Store the classes for the legend if they're None.
if self.classes_ is None:
# TODO: Is this the most efficient method?
self.classes_ = [str(label) for label in np.unique(y)]
self._determine_target_color_type(y)

if self._target_color_type == TargetType.SINGLE:
self._colors = 'b'

# Compute classes and colors if target type is discrete
elif self._target_color_type == TargetType.DISCRETE:
# Store the classes for the legend if they're None.
if self.classes_ is None:
# TODO: Is this the most efficient method?
self.classes_ = [str(label) for label in np.unique(y)]

# Ensures that classes passed by user is equal to that in target
if len(self.classes_)!=len(np.unique(y)):
warnings.warn(("Number of unique target is not "
"equal to classes"), YellowbrickWarning)

color_values = resolve_colors(n_colors=len(self.classes_),
colormap=self.colormap, colors=self.color)
self._colors = dict(zip(self.classes_, color_values))

# Compute target range if colors are continuous
elif self._target_color_type == TargetType.CONTINUOUS:
y = np.asarray(y)
self.range_ = (y.min(), y.max())

self._colors = mpl.cm.get_cmap(self.colormap)

else:
raise YellowbrickValueError(
"unknown target color type '{}'".format(self._target_color_type))

# Draw the instances
self.draw(X, y, **kwargs)

# Fit always returns self.
return self

def _determine_target_color_type(self, y):
"""
Determines the target color type from the vector y as follows:
- if y is None: only a single color is used
- if target is auto: determine if y is continuous or discrete
- otherwise specify supplied target type
This property will be used to compute the colors for each point.
"""
if y is None:
self._target_color_type = TargetType.SINGLE
elif self.target_type == TargetType.AUTO:
# NOTE: See #73 for a generalization to use when implemented
if len(np.unique(y)) < 10:
self._target_color_type = TargetType.DISCRETE
else:
self._target_color_type = TargetType.CONTINUOUS
else:
self._target_color_type = self.target_type

# Ensures that target is either SINGLE, DISCRETE or CONTINUOS and not AUTO
if self._target_color_type == TargetType.AUTO:
raise YellowbrickValueError((
"could not determine target color type "
"from target='{}' to '{}'"
).format(self.target_type, self._target_color_type))
3 changes: 2 additions & 1 deletion yellowbrick/features/pcoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def __init__(self,
vlines=True,
vlines_kwds=None,
**kwargs):

if "target_type" not in kwargs:
kwargs["target_type"] = "discrete"
super(ParallelCoordinates, self).__init__(
ax, features, classes, color, colormap, **kwargs
)
Expand Down
2 changes: 2 additions & 0 deletions yellowbrick/features/radviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class RadialVisualizer(DataVisualizer):

def __init__(self, ax=None, features=None, classes=None, color=None,
colormap=None, alpha=1.0, **kwargs):
if "target_type" not in kwargs:
kwargs["target_type"] = "discrete"
super(RadialVisualizer, self).__init__(
ax, features, classes, color, colormap, **kwargs
)
Expand Down

0 comments on commit 39a4a02

Please sign in to comment.