From a3bca9355dbdd90945f71209cc6520beac844ba3 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Tue, 1 Aug 2023 13:54:51 +0200
Subject: [PATCH 1/2] MBert Keywords - Handle failed embedding due to
connection error
---
orangecontrib/text/keywords/mbert.py | 6 +++++-
orangecontrib/text/tests/test_keywords.py | 9 +++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/orangecontrib/text/keywords/mbert.py b/orangecontrib/text/keywords/mbert.py
index 64e4fb475..ee4024171 100644
--- a/orangecontrib/text/keywords/mbert.py
+++ b/orangecontrib/text/keywords/mbert.py
@@ -5,6 +5,7 @@
from typing import Optional, Callable, Tuple, List, Iterable
import numpy as np
+from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from nltk import ngrams
from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback
@@ -68,7 +69,10 @@ def mbert_keywords(
server_url="https://api.garaza.io",
embedder_type="text",
)
- keywords = emb.embedd_data(documents, callback=progress_callback)
+ try:
+ keywords = emb.embedd_data(documents, callback=progress_callback)
+ except EmbeddingConnectionError:
+ keywords = [None] * len(documents)
processed_kws = []
for kws in keywords:
if kws is not None:
diff --git a/orangecontrib/text/tests/test_keywords.py b/orangecontrib/text/tests/test_keywords.py
index 2e494c487..c6e34d48e 100644
--- a/orangecontrib/text/tests/test_keywords.py
+++ b/orangecontrib/text/tests/test_keywords.py
@@ -4,6 +4,7 @@
import numpy as np
from Orange.data import Domain, StringVariable
+from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from orangecontrib.text import Corpus
from orangecontrib.text.keywords import (
@@ -180,6 +181,14 @@ def test_mbert_keywords(self, _):
]
self.assertListEqual(expected, res)
+ @patch(
+ "orangecontrib.text.keywords.mbert._BertServerCommunicator.embedd_data",
+ side_effect=EmbeddingConnectionError,
+ )
+ def test_mbert_keywords_fail(self, _):
+ res = mbert_keywords(["Text 1", "Text 2"], max_len=3)
+ self.assertListEqual([None, None], res)
+
@patch(
"orangecontrib.text.keywords.mbert._BertServerCommunicator._send_request",
From 9867d3c3f9cb108913c45cf7a4a43698a9a7b44c Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Wed, 2 Aug 2023 12:25:29 +0200
Subject: [PATCH 2/2] Keywords - Handle failed embedding due to connection
error
---
orangecontrib/text/widgets/owkeywords.py | 31 +++++++--
.../text/widgets/tests/test_owkeywords.py | 63 ++++++++++++++++++-
2 files changed, 87 insertions(+), 7 deletions(-)
diff --git a/orangecontrib/text/widgets/owkeywords.py b/orangecontrib/text/widgets/owkeywords.py
index 209386568..58466e529 100644
--- a/orangecontrib/text/widgets/owkeywords.py
+++ b/orangecontrib/text/widgets/owkeywords.py
@@ -28,6 +28,10 @@
WORDS_COLUMN_NAME
YAKE_LANGUAGES = list(YAKE_LANGUAGE_MAPPING.keys())
+CONNECTION_WARNING = (
+ f"{ScoringMethods.MBERT} could not extract keywords from some "
+ "documents due to connection error. Please rerun keyword extraction."
+)
class Results(SimpleNamespace):
@@ -37,6 +41,8 @@ class Results(SimpleNamespace):
labels: List[str] = []
# all calculated keywords {method: [[(word1, score1), ...]]}
all_keywords: Dict[str, List[List[Tuple[str, float]]]] = {}
+    # warnings raised during the keyword extraction process
+ warnings: List[str] = []
def run(
@@ -48,7 +54,7 @@ def run(
agg_method: int,
state: TaskState
) -> Results:
- results = Results(scores=[], labels=[], all_keywords={})
+ results = Results(scores=[], labels=[], all_keywords={}, warnings=[])
if not corpus:
return results
@@ -70,7 +76,8 @@ def callback(i: float, status=""):
step = 1 / len(scoring_methods)
for method_name, func in ScoringMethods.ITEMS:
if method_name in scoring_methods:
- if method_name not in results.all_keywords:
+ keywords = results.all_keywords.get(method_name)
+ if keywords is None:
i = len(results.labels)
cb = wrap_callback(callback, start=i * step,
end=(i + 1) * step)
@@ -79,10 +86,20 @@ def callback(i: float, status=""):
kw = {"progress_callback": cb}
kw.update(scoring_methods_kwargs.get(method_name, {}))
- keywords = func(corpus if needs_tokens else documents, **kw)
- results.all_keywords[method_name] = keywords
+ kws = func(corpus if needs_tokens else documents, **kw)
+                    # None means that embedding completely failed for a document;
+                    # currently this only happens with mBERT when the connection fails
+ keywords = [kws for kws in kws if kws is not None]
+                    # don't store keywords in all_keywords if any were not computed
+                    # due to connection issues; storing them would prevent the
+                    # missing keywords from being recomputed on the next run
+                    # (mBERT's successfully computed keywords are kept in the
+                    # embedding cache, so only the missing ones are recomputed)
+ if len(kws) > len(keywords) and method_name == ScoringMethods.MBERT:
+ results.warnings.append(CONNECTION_WARNING)
+ else:
+ results.all_keywords[method_name] = keywords
- keywords = results.all_keywords[method_name]
scores[method_name] = \
dict(AggregationMethods.aggregate(keywords, agg_method))
@@ -210,6 +227,7 @@ class Outputs:
class Warning(OWWidget.Warning):
no_words_column = Msg("Input is missing 'Words' column.")
+ extraction_warnings = Msg("{}")
def __init__(self):
OWWidget.__init__(self)
@@ -376,6 +394,7 @@ def handleNewSignals(self):
self.update_scores()
def update_scores(self):
+ self.Warning.extraction_warnings.clear()
kwargs = {
ScoringMethods.YAKE: {
"language": YAKE_LANGUAGES[self.yake_lang_index],
@@ -441,6 +460,8 @@ def on_done(self, results: Results):
self._select_rows()
else:
self.__on_selection_changed()
+ if results.warnings:
+ self.Warning.extraction_warnings("\n".join(results.warnings))
def _apply_sorting(self):
if self.model.columnCount() <= self.sort_column_order[0]:
diff --git a/orangecontrib/text/widgets/tests/test_owkeywords.py b/orangecontrib/text/widgets/tests/test_owkeywords.py
index ae301fbce..200e77246 100644
--- a/orangecontrib/text/widgets/tests/test_owkeywords.py
+++ b/orangecontrib/text/widgets/tests/test_owkeywords.py
@@ -13,8 +13,14 @@
from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \
rake_keywords
from orangecontrib.text.preprocess import *
-from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \
- AggregationMethods, ScoringMethods, SelectionMethods
+from orangecontrib.text.widgets.owkeywords import (
+ OWKeywords,
+ run,
+ AggregationMethods,
+ ScoringMethods,
+ SelectionMethods,
+ CONNECTION_WARNING,
+)
from orangecontrib.text.widgets.utils.words import create_words_table
@@ -111,6 +117,27 @@ def test_run_interrupt(self):
{ScoringMethods.TF_IDF}, {},
AggregationMethods.MEAN, state)
+ def test_run_mbert_fail(self):
+ """Test mbert partially or completely fails due to connection issues"""
+ agg, sc = AggregationMethods.MEAN, {ScoringMethods.MBERT}
+ res = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
+ with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+ results = run(self.corpus[:3], None, {}, sc, {}, agg, self.state)
+ self.assertListEqual([["keyword1", 7.5], ["keyword2", 1]], results.scores)
+ self.assertListEqual(["mBERT"], results.labels)
+            # not stored to all_keywords since not all keywords were extracted successfully
+ self.assertDictEqual({}, results.all_keywords)
+ self.assertListEqual([CONNECTION_WARNING], results.warnings)
+
+ res = [None] * 3
+ with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+ results = run(self.corpus[:3], None, {}, sc, {}, agg, self.state)
+ self.assertListEqual([], results.scores)
+ self.assertListEqual(["mBERT"], results.labels)
+            # not stored to all_keywords since not all keywords were extracted successfully
+ self.assertDictEqual({}, results.all_keywords)
+ self.assertListEqual([CONNECTION_WARNING], results.warnings)
+
def assertNanEqual(self, table1, table2):
for list1, list2 in zip(table1, table2):
for x1, x2 in zip(list1, list2):
@@ -274,6 +301,38 @@ def test_selection_n_best(self):
output = self.get_output(self.widget.Outputs.words)
self.assertEqual(5, len(output))
+ def test_connection_error(self):
+ self.widget.controlArea.findChildren(QCheckBox)[0].click() # unselect tfidf
+ self.widget.controlArea.findChildren(QCheckBox)[3].click() # unselect mbert
+ res = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
+ with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ output = self.get_output(self.widget.Outputs.words)
+ self.assertEqual(len(output), 2)
+ np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
+ np.testing.assert_array_equal(output.X, [[7.5], [1]])
+ self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
+ self.assertEqual(
+ CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
+ )
+
+ res = [None] * 3 # all failed
+ with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.assertIsNone(self.get_output(self.widget.Outputs.words))
+ self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
+ self.assertEqual(
+ CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
+ )
+
+ res = [[("keyword1", 10), ("keyword2", 2)], [("keyword1", 5)]]
+ with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ output = self.get_output(self.widget.Outputs.words)
+ np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
+ np.testing.assert_array_equal(output.X, [[7.5], [1]])
+ self.assertFalse(self.widget.Warning.extraction_warnings.is_shown())
+
if __name__ == "__main__":
unittest.main()