Skip to content

Commit

Permalink
Adds ProjectionVisualizer base class (#908)
Browse files Browse the repository at this point in the history
This is the first major step toward completing #874: the implementation of a ProjectionVisualizer base class to unify functionality of decomposition visualizers that use PCA and Manifold and to extend support to other decomposition methods. In a follow-up PR, we will reorganize this class and extend the functionality in Manifold and PCA.
  • Loading branch information
naresh-bachwani authored and bbengfort committed Jul 17, 2019
1 parent 6608c49 commit 21eb9d2
Show file tree
Hide file tree
Showing 13 changed files with 564 additions and 16 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 3 additions & 9 deletions tests/test_features/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
##########################################################################

import pytest
from unittest.mock import patch

from yellowbrick.base import Visualizer
from yellowbrick.features.base import *
Expand Down Expand Up @@ -80,8 +79,7 @@ def test_subclass(self):
@pytest.mark.usefixtures("discrete", "continuous")
class TestDataVisualizerBase(VisualTestCase):

@patch.object(DataVisualizer, 'draw')
def test_single(self, mock_draw):
def test_single(self):

dataviz = DataVisualizer()
# Check default is auto
Expand All @@ -95,7 +93,6 @@ def test_single(self, mock_draw):
dataviz = DataVisualizer(target_type="continuous")
X, y = self.continuous
dataviz.fit(X)
mock_draw.assert_called_once()
assert dataviz._colors == 'b'
assert dataviz._target_color_type == TargetType.SINGLE

Expand All @@ -104,13 +101,11 @@ def test_single(self, mock_draw):
dataviz._determine_target_color_type(None)
assert dataviz._target_color_type == TargetType.SINGLE

@patch.object(DataVisualizer, 'draw')
def test_continuous(self, mock_draw):
def test_continuous(self):
# Check when y is continuous
X, y = self.continuous
dataviz = DataVisualizer()
dataviz.fit(X, y)
mock_draw.assert_called_once()
assert hasattr(dataviz, "range_")
assert dataviz._target_color_type == TargetType.CONTINUOUS

Expand Down Expand Up @@ -140,8 +135,7 @@ def test_bad_target(self):
with pytest.raises(YellowbrickValueError, match=msg):
DataVisualizer(target_type="foo")

@patch.object(DataVisualizer, 'draw')
def test_classes(self, mock_draw):
def test_classes(self):
# Checks that classes are assigned correctly
X, y = self.discrete
classes = ['a', 'b', 'c', 'd', 'e']
Expand Down
243 changes: 243 additions & 0 deletions tests/test_features/test_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# tests.test_features.test_projection
# Test the base ProjectionVisualizer drawing functionality
#
# Author: Naresh Bachwani
# Created: Wed Jul 17 09:53:07 2019 -0400
#
# Copyright (C) 2019, the scikit-yb developers.
# For license information, see LICENSE.txt
#
# ID: test_projection.py [] [email protected] $

"""
Test the base ProjectionVisualizer drawing functionality
"""

##########################################################################
## Imports
##########################################################################


import pytest
import matplotlib.pyplot as plt

from yellowbrick.features.projection import *
from yellowbrick.exceptions import YellowbrickValueError

from tests.base import VisualTestCase
from ..fixtures import Dataset
from unittest import mock

from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification, make_regression


##########################################################################
## Fixtures
##########################################################################

@pytest.fixture(scope="class")
def discrete(request):
    """
    Create a random classification dataset and attach it to the requesting
    test class as the ``discrete`` attribute.
    """
    features, target = make_classification(
        n_samples=400,
        n_features=12,
        n_informative=10,
        n_redundant=0,
        n_classes=5,
        random_state=2019,
    )

    # Expose the generated data to the tests through a class attribute
    request.cls.discrete = Dataset(features, target)


@pytest.fixture(scope="class")
def continuous(request):
    """
    Create a random regression dataset and attach it to the requesting
    test class as the ``continuous`` attribute.
    """
    features, target = make_regression(
        n_samples=500,
        n_features=22,
        n_informative=8,
        random_state=2019,
    )

    # Expose the generated data to the tests through a class attribute
    request.cls.continuous = Dataset(features, target)


##########################################################################
## MockVisualizer
##########################################################################

class MockVisualizer(ProjectionVisualizer):
    """
    The MockVisualizer implements the ProjectionVisualizer interface using
    PCA as an internal transformer. This visualizer is used to directly test
    how subclasses interact with the ProjectionVisualizer base class.

    All constructor arguments are forwarded unchanged to the
    ProjectionVisualizer initializer. After that call, ``self.projection``
    is used as the number of components of the internal PCA step.
    """

    def __init__(self, ax=None, features=None, classes=None, color=None,
                 colormap=None, target_type="auto", projection=2,
                 alpha=0.75, **kwargs):

        super(MockVisualizer, self).__init__(
            ax=ax,
            features=features,
            classes=classes,
            color=color,
            colormap=colormap,
            target_type=target_type,
            projection=projection,
            alpha=alpha,
            **kwargs
        )

        # Standardize the data, then project it with PCA. The PCA step reads
        # self.projection as set by the base initializer and uses a fixed
        # random_state.
        self.pca_transformer = Pipeline([
            ("scale", StandardScaler()),
            ("pca", PCA(self.projection, random_state=2019)),
        ])

    def fit(self, X, y=None):
        """
        Run the base class fit on (X, y), then fit the internal
        scale + PCA pipeline on X. Returns self.
        """
        super(MockVisualizer, self).fit(X, y)
        self.pca_transformer.fit(X)
        return self

    def transform(self, X, y=None):
        """
        Project X with the internal pipeline, draw the projected points
        together with y, and return the projected data.

        Raises
        ------
        AttributeError
            If the pipeline's transform raises AttributeError (presumably
            because it has not been fitted yet); the message is extended
            with a hint to use fit_transform.
        """
        try:
            Xp = self.pca_transformer.transform(X)
        except AttributeError as e:
            # Chain explicitly so the traceback shows the original error as
            # the direct cause rather than as an unrelated failure.
            raise AttributeError(
                str(e) + " try using fit_transform instead."
            ) from e
        self.draw(Xp, y)
        return Xp


##########################################################################
## ProjectionVisualizer Tests
##########################################################################

@pytest.mark.usefixtures("discrete", "continuous")
class TestProjectionVisualizer(VisualTestCase):
    """
    Exercise the ProjectionVisualizer base class through MockVisualizer.
    """

    def test_discrete_plot(self):
        """
        Draw a 2D projection colored by a discrete target.
        """
        X, y = self.discrete
        labels = ["a", "b", "c", "d", "e"]
        viz = MockVisualizer(projection=2, colormap="plasma", classes=labels)
        projected = viz.fit_transform(X, y)
        assert viz.classes_ == labels
        viz.finalize()
        self.assert_images_similar(viz)
        assert projected.shape == (self.discrete.X.shape[0], 2)

    def test_continuous_plot(self):
        """
        Draw a 2D projection colored by a continuous target.
        """
        X, y = self.continuous
        viz = MockVisualizer(projection="2d")
        viz.fit_transform(X, y)
        viz.finalize()
        # Clear the tick labels on the visualizer's cax before comparing
        viz.cax.set_yticklabels([])
        self.assert_images_similar(viz)

    def test_continuous_when_target_discrete(self):
        """
        A discrete target can be drawn as continuous when the user sets
        target_type="continuous" explicitly.
        """
        ax = plt.subplots()[1]
        X, y = self.discrete
        viz = MockVisualizer(
            ax=ax, projection="2D", target_type="continuous", colormap="cool"
        )
        viz.fit(X, y)
        viz.transform(X, y)
        viz.finalize()
        # Clear the tick labels on the visualizer's cax before comparing
        viz.cax.set_yticklabels([])
        self.assert_images_similar(viz)

    def test_single_plot(self):
        """
        Draw a single color plot when no y is passed.
        """
        X, _ = self.discrete
        viz = MockVisualizer(projection=2, colormap="plasma")
        viz.fit_transform(X)
        viz.finalize()
        self.assert_images_similar(viz)

    def test_discrete_3d(self):
        """
        Draw a 3D projection colored by a discrete target.
        """
        X, y = self.discrete

        labels = ["a", "b", "c", "d", "e"]
        palette = ["r", "b", "g", "m", "c"]
        viz = MockVisualizer(projection=3, color=palette, classes=labels)
        viz.fit_transform(X, y)
        assert viz.classes_ == labels
        viz.finalize()
        self.assert_images_similar(viz)

    def test_3d_continuous_plot(self):
        """
        Draw a 3D projection colored by a continuous target.
        """
        X, y = self.continuous
        viz = MockVisualizer(projection="3D")
        viz.fit_transform(X, y)
        viz.finalize()
        # Remove the ticks on the visualizer's cbar before comparing
        viz.cbar.set_ticks([])
        self.assert_images_similar(viz)

    def test_alpha_param(self):
        """
        The alpha parameter is stored and forwarded to the scatter call.
        """
        # Instantiate the mock visualizer with a custom alpha and swap its
        # axes for a MagicMock so the scatter call can be inspected
        X, y = self.discrete
        viz = MockVisualizer(alpha=0.3, projection=2)
        viz.ax = mock.MagicMock()
        viz.fit(X, y)
        viz.transform(X, y)

        assert viz.alpha == 0.3

        # The keyword arguments of the last ax.scatter call must carry alpha
        scatter_kwargs = viz.ax.scatter.call_args[1]
        assert "alpha" in scatter_kwargs
        assert scatter_kwargs["alpha"] == 0.3

    # Check Errors
    @pytest.mark.parametrize("projection", ["4D", 1, "100d", 0])
    def test_wrong_projection_dimensions(self, projection):
        """
        Invalid projection values are rejected at construction time.
        """
        expected = "Projection dimensions must be either 2 or 3"
        with pytest.raises(YellowbrickValueError, match=expected):
            MockVisualizer(projection=projection)

    def test_target_not_label_encoded(self):
        """
        A target that is not label encoded raises an exception.
        """
        X, y = self.discrete
        # Scale every label by 10 so the target is no longer label encoded
        scaled_y = y * 10
        viz = MockVisualizer()
        expected = "Target needs to be label encoded."
        with pytest.raises(YellowbrickValueError, match=expected):
            viz.fit_transform(X, scaled_y)

    @pytest.mark.parametrize("dataset", ("discrete", "continuous"))
    def test_y_required_for_discrete_and_continuous(self, dataset):
        """
        Calling transform without y fails after fitting with a target.
        """
        X, y = getattr(self, dataset)
        viz = MockVisualizer()
        viz.fit(X, y)

        expected = "y is required for {} target".format(dataset)
        with pytest.raises(YellowbrickValueError, match=expected):
            viz.transform(X)
4 changes: 2 additions & 2 deletions yellowbrick/contrib/missing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def fit(self, X, y=None, **kwargs):

self.y = y

super(MissingDataVisualizer, self).fit(X, y, **kwargs)

self.draw(X, y, **kwargs)
return self

def get_feature_names(self):
if self.features_ is None:
Expand Down
4 changes: 1 addition & 3 deletions yellowbrick/features/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,8 @@ def fit(self, X, y=None, **kwargs):
raise YellowbrickValueError(
"unknown target color type '{}'".format(self._target_color_type))

# Draw the instances
self.draw(X, y, **kwargs)

# Fit always returns self.
# NOTE: cannot call draw in fit to support data transformers
return self

def _determine_target_color_type(self, y):
Expand Down
5 changes: 3 additions & 2 deletions yellowbrick/features/pcoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ def fit(self, X, y=None, **kwargs):
if self.normalize is not None:
X = self.NORMALIZERS[self.normalize].fit_transform(X)

# the super method calls draw and returns self
return super(ParallelCoordinates, self).fit(X, y, **kwargs)
self.draw(X, y, **kwargs)

return self

def draw(self, X, y, **kwargs):
"""
Expand Down
Loading

0 comments on commit 21eb9d2

Please sign in to comment.