Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FreeViz: Allow setting ratio btw attractive and repulsive forces #6515

Merged
merged 1 commit into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions Orange/projection/freeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class FreeViz(LinearProjector):
projection = FreeVizModel

def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1,
initial=None, maxiter=500, alpha=0.1, gravity=None,
atol=1e-5, preprocessors=None):
super().__init__(preprocessors=preprocessors)
self.weights = weights
Expand All @@ -33,6 +33,7 @@ def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
self.maxiter = maxiter
self.alpha = alpha
self.atol = atol
self.gravity = gravity
self.is_class_discrete = False
self.components_ = None

Expand All @@ -50,6 +51,7 @@ def get_components(self, X, Y):
X, Y, weights=self.weights, center=self.center, scale=self.scale,
dim=self.dim, p=self.p, initial=self.initial,
maxiter=self.maxiter, alpha=self.alpha, atol=self.atol,
gravity=self.gravity,
is_class_discrete=self.is_class_discrete)[1].T

@classmethod
Expand Down Expand Up @@ -104,7 +106,7 @@ def forces_regression(cls, distances, y, p=1):
return F

@classmethod
def forces_classification(cls, distances, y, p=1):
def forces_classification(cls, distances, y, p=1, gravity=None):
diffclass = scipy.spatial.distance.pdist(y.reshape(-1, 1), "hamming") != 0
# handle attractive force
if p == 1:
Expand All @@ -120,6 +122,8 @@ def forces_classification(cls, distances, y, p=1):
F[mask] = 1 / distances[mask]
else:
F[mask] = 1 / (distances[mask] ** p)
if gravity is not None:
F[mask] *= -np.sum(F[~mask]) / np.sum(F[mask]) / gravity
return F

@classmethod
Expand Down Expand Up @@ -180,7 +184,8 @@ def gradient(cls, X, embeddings, forces, embedding_dist=None, weights=None):
return G

@classmethod
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=False):
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None,
gravity=None, is_class_discrete=False):
"""
Return the gradient for the FreeViz [1]_ projection.

Expand Down Expand Up @@ -214,7 +219,7 @@ def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=
assert X.ndim == 2 and X.shape[0] == y.shape[0] == embedding.shape[0]
D = scipy.spatial.distance.pdist(embedding)
if is_class_discrete:
forces = cls.forces_classification(D, y, p=p)
forces = cls.forces_classification(D, y, p=p, gravity=gravity)
else:
forces = cls.forces_regression(D, y, p=p)
G = cls.gradient(X, embedding, forces, embedding_dist=D, weights=weights)
Expand All @@ -234,7 +239,8 @@ def _rotate(cls, A):

@classmethod
def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1, atol=1e-5, is_class_discrete=False):
initial=None, maxiter=500, alpha=0.1, atol=1e-5, gravity=None,
is_class_discrete=False):
"""
FreeViz

Expand Down Expand Up @@ -341,6 +347,7 @@ def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
step_i = 0
while step_i < maxiter:
G = cls.freeviz_gradient(X, y, embeddings, p=p, weights=weights,
gravity=gravity,
is_class_discrete=is_class_discrete)

# Scale the changes (the largest anchor move is alpha * radius)
Expand Down
44 changes: 42 additions & 2 deletions Orange/widgets/visualize/owfreeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint
from AnyQt.QtGui import QPalette
from AnyQt.QtGui import QPalette, QFontMetrics
from AnyQt.QtWidgets import QSizePolicy

import pyqtgraph as pg
Expand Down Expand Up @@ -137,9 +137,13 @@ class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin):

settings_version = 3
initialization = settings.Setting(InitType.Circular)
balance = settings.Setting(False)
gravity_index = settings.Setting(4)
GRAPH_CLASS = OWFreeVizGraph
graph = settings.SettingProvider(OWFreeVizGraph)

GravityValues = [0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]

class Error(OWAnchorProjectionWidget.Error):
no_class_var = widget.Msg("Data must have a target variable.")
multiple_class_vars = widget.Msg(
Expand All @@ -159,6 +163,7 @@ class Warning(OWAnchorProjectionWidget.Warning):
def __init__(self):
OWAnchorProjectionWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self.__optimized = False
VesnaT marked this conversation as resolved.
Show resolved Hide resolved

def _add_controls(self):
self.__add_controls_start_box()
Expand All @@ -177,6 +182,20 @@ def __add_controls_start_box(self):
callback=self.__init_combo_changed,
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
)
box2 = gui.hBox(box)
gui.checkBox(
box2, self, "balance", "Gravity",
callback=self.__gravity_changed)
self.grav_slider = gui.hSlider(
box2, self, "gravity_index",
minValue=0, maxValue=len(self.GravityValues) - 1,
callback=self.__gravity_dragged, createLabel=False)
self.gravity_label = gui.widgetLabel(box2)
self.gravity_label.setFixedWidth(
max(QFontMetrics(self.font()).horizontalAdvance(str(x))
for x in self.GravityValues))
self.gravity_label.setAlignment(Qt.AlignRight)
self.__update_gravity_label()
self.run_button = gui.button(box, self, "Start", self._toggle_run)

@property
Expand All @@ -189,6 +208,21 @@ def effective_data(self):
return self.data.transform(Domain(self.effective_variables,
self.data.domain.class_vars))

def __gravity_dragged(self):
self.balance = True
self.__gravity_changed()

def __update_gravity_label(self):
self.gravity_label.setText(str(self.GravityValues[self.gravity_index]))

def __gravity_changed(self):
gravity = self.GravityValues[self.gravity_index]
if self.projector is not None:
self.projector.gravity = gravity if self.balance else None
self.__update_gravity_label()
if self.task is None and self.__optimized:
self._run()

def __radius_slider_changed(self):
self.graph.update_radius()

Expand Down Expand Up @@ -232,6 +266,7 @@ def on_done(self, result: Result):
self.projection = result.projection
self.graph.set_sample_size(None)
self.run_button.setText("Start")
self.__optimized = True
self.commit.deferred()

def on_exception(self, ex: Exception):
Expand All @@ -253,14 +288,19 @@ def init_projection(self):
anchors = FreeViz.init_radial(len(self.effective_variables)) \
if self.initialization == InitType.Circular \
else FreeViz.init_random(len(self.effective_variables), 2)
if self.balance:
gravity = self.GravityValues[self.gravity_index]
else:
gravity = None
self.projector = FreeViz(scale=False, center=False,
initial=anchors, maxiter=10)
initial=anchors, maxiter=10, gravity=gravity)
data = self.projector.preprocess(self.effective_data)
self.projector.domain = data.domain
self.projector.components_ = anchors.T
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
self.projection.pre_domain = data.domain
self.projection.name = self.projector.name
self.__optimized = False

def check_data(self):
def error(err):
Expand Down
42 changes: 41 additions & 1 deletion Orange/widgets/visualize/tests/test_owfreeviz.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import unittest
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np

Expand Down Expand Up @@ -156,6 +156,46 @@ def test_discrete_attributes(self):
self.assertTrue(self.widget.Warning.removed_features.is_shown())
self.widget.run_button.click()

def test_gravity_slider(self):
w = self.widget

w.balance = False
w.gravity_index = 0

w.grav_slider.setValue(2)
self.assertTrue(w.balance)
self.assertEqual(w.gravity_label.text(), str(w.GravityValues[2]))

w.grav_slider.setValue(3)
self.assertTrue(w.balance)
self.assertEqual(w.gravity_label.text(), str(w.GravityValues[3]))

assert w.projector is None
self.send_signal(self.widget.Inputs.data, Table("zoo"))
self.wait_until_finished()
assert w.projector is not None

# w.projector.gravity has correct value if gravity was set before data
self.assertEqual(w.projector.gravity, w.GravityValues[3])

# ... and if set when the data is already present and projector exists
w.grav_slider.setValue(1)
self.assertEqual(w.projector.gravity, w.GravityValues[1])

# Check that optimization is restarted if the projection is optimized
with patch.object(w, "_run") as run, \
patch.object(w, "_OWFreeViz__optimized", new=True):
w.grav_slider.setValue(2)
self.assertEqual(w.projector.gravity, w.GravityValues[2])
run.assert_called_once()

# Also, check that checkbox also does all that
run.reset_mock()
w.controls.balance.click()
self.assertFalse(w.balance)
self.assertIsNone(w.projector.gravity)
run.assert_called_once()


class TestOWFreeVizRunner(unittest.TestCase):
@classmethod
Expand Down
Loading