diff --git a/orangecontrib/text/keywords/mbert.py b/orangecontrib/text/keywords/mbert.py
index 64e4fb475..ee4024171 100644
--- a/orangecontrib/text/keywords/mbert.py
+++ b/orangecontrib/text/keywords/mbert.py
@@ -5,6 +5,7 @@
 from typing import Optional, Callable, Tuple, List, Iterable
 
 import numpy as np
+from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
 from nltk import ngrams
 from Orange.misc.server_embedder import ServerEmbedderCommunicator
 from Orange.util import dummy_callback
@@ -68,7 +69,10 @@ def mbert_keywords(
         server_url="https://api.garaza.io",
         embedder_type="text",
     )
-    keywords = emb.embedd_data(documents, callback=progress_callback)
+    try:
+        keywords = emb.embedd_data(documents, callback=progress_callback)
+    except EmbeddingConnectionError:
+        keywords = [None] * len(documents)
     processed_kws = []
     for kws in keywords:
         if kws is not None:
diff --git a/orangecontrib/text/tests/test_keywords.py b/orangecontrib/text/tests/test_keywords.py
index 2e494c487..c6e34d48e 100644
--- a/orangecontrib/text/tests/test_keywords.py
+++ b/orangecontrib/text/tests/test_keywords.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 from Orange.data import Domain, StringVariable
+from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
 
 from orangecontrib.text import Corpus
 from orangecontrib.text.keywords import (
@@ -180,6 +181,14 @@ def test_mbert_keywords(self, _):
         ]
         self.assertListEqual(expected, res)
 
+    @patch(
+        "orangecontrib.text.keywords.mbert._BertServerCommunicator.embedd_data",
+        side_effect=EmbeddingConnectionError,
+    )
+    def test_mbert_keywords_fail(self, _):
+        res = mbert_keywords(["Text 1", "Text 2"], max_len=3)
+        self.assertListEqual([None, None], res)
+
     @patch(
         "orangecontrib.text.keywords.mbert._BertServerCommunicator._send_request",
diff --git a/orangecontrib/text/widgets/owkeywords.py b/orangecontrib/text/widgets/owkeywords.py
index 209386568..58466e529 100644
--- a/orangecontrib/text/widgets/owkeywords.py
+++ b/orangecontrib/text/widgets/owkeywords.py
@@ -28,6 +28,10 @@
     WORDS_COLUMN_NAME
 
 YAKE_LANGUAGES = list(YAKE_LANGUAGE_MAPPING.keys())
+CONNECTION_WARNING = (
+    f"{ScoringMethods.MBERT} could not extract keywords from some "
+    "documents due to a connection error. Please rerun keyword extraction."
+)
 
 
 class Results(SimpleNamespace):
@@ -37,6 +41,8 @@ class Results(SimpleNamespace):
     labels: List[str] = []
     # all calculated keywords {method: [[(word1, score1), ...]]}
     all_keywords: Dict[str, List[List[Tuple[str, float]]]] = {}
+    # warnings raised during the keyword extraction process
+    warnings: List[str] = []
 
 
 def run(
@@ -48,7 +54,7 @@ def run(
     agg_method: int,
     state: TaskState
 ) -> Results:
-    results = Results(scores=[], labels=[], all_keywords={})
+    results = Results(scores=[], labels=[], all_keywords={}, warnings=[])
     if not corpus:
         return results
 
@@ -70,7 +76,8 @@ def callback(i: float, status=""):
     step = 1 / len(scoring_methods)
     for method_name, func in ScoringMethods.ITEMS:
         if method_name in scoring_methods:
-            if method_name not in results.all_keywords:
+            keywords = results.all_keywords.get(method_name)
+            if keywords is None:
                 i = len(results.labels)
                 cb = wrap_callback(callback, start=i * step, end=(i + 1) * step)
 
@@ -79,10 +86,20 @@ def callback(i: float, status=""):
                 kw = {"progress_callback": cb}
                 kw.update(scoring_methods_kwargs.get(method_name, {}))
 
-                keywords = func(corpus if needs_tokens else documents, **kw)
-                results.all_keywords[method_name] = keywords
+                kws = func(corpus if needs_tokens else documents, **kw)
+                # None means the embedding failed completely for a document;
+                # currently this only happens with mBERT when the connection fails
+                keywords = [k for k in kws if k is not None]
+                # don't store keywords in all_keywords if any were not computed
+                # due to connection issues; storing them would prevent the missing
+                # keywords from being recomputed on the next run. mBERT's existing
+                # keywords are cached in the embedding cache, so only the missing
+                # ones are recomputed
+                if len(kws) > len(keywords) and method_name == ScoringMethods.MBERT:
+                    results.warnings.append(CONNECTION_WARNING)
+                else:
+                    results.all_keywords[method_name] = keywords
 
-            keywords = results.all_keywords[method_name]
             scores[method_name] = \
                 dict(AggregationMethods.aggregate(keywords, agg_method))
 
@@ -210,6 +227,7 @@ class Outputs:
 
     class Warning(OWWidget.Warning):
         no_words_column = Msg("Input is missing 'Words' column.")
+        extraction_warnings = Msg("{}")
 
     def __init__(self):
         OWWidget.__init__(self)
@@ -376,6 +394,7 @@ def handleNewSignals(self):
         self.update_scores()
 
     def update_scores(self):
+        self.Warning.extraction_warnings.clear()
         kwargs = {
             ScoringMethods.YAKE: {
                 "language": YAKE_LANGUAGES[self.yake_lang_index],
@@ -441,6 +460,8 @@ def on_done(self, results: Results):
             self._select_rows()
         else:
             self.__on_selection_changed()
+        if results.warnings:
+            self.Warning.extraction_warnings("\n".join(results.warnings))
 
     def _apply_sorting(self):
         if self.model.columnCount() <= self.sort_column_order[0]:
diff --git a/orangecontrib/text/widgets/tests/test_owkeywords.py b/orangecontrib/text/widgets/tests/test_owkeywords.py
index ae301fbce..200e77246 100644
--- a/orangecontrib/text/widgets/tests/test_owkeywords.py
+++ b/orangecontrib/text/widgets/tests/test_owkeywords.py
@@ -13,8 +13,14 @@
 from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \
     rake_keywords
 from orangecontrib.text.preprocess import *
-from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \
-    AggregationMethods, ScoringMethods, SelectionMethods
+from orangecontrib.text.widgets.owkeywords import (
+    OWKeywords,
+    run,
+    AggregationMethods,
+    ScoringMethods,
+    SelectionMethods,
+    CONNECTION_WARNING,
+)
 from orangecontrib.text.widgets.utils.words import create_words_table
 
 
@@ -111,6 +117,27 @@ def test_run_interrupt(self):
             {ScoringMethods.TF_IDF}, {}, AggregationMethods.MEAN, state)
 
+    def test_run_mbert_fail(self):
+        """Test that mBERT partially or completely fails due to connection issues"""
+        agg, sc = AggregationMethods.MEAN, {ScoringMethods.MBERT}
+        res = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
+        with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+            results = run(self.corpus[:3], None, {}, sc, {}, agg, self.state)
+        self.assertListEqual([["keyword1", 7.5], ["keyword2", 1]], results.scores)
+        self.assertListEqual(["mBERT"], results.labels)
+        # not stored in all_keywords since not all keywords were extracted
+        self.assertDictEqual({}, results.all_keywords)
+        self.assertListEqual([CONNECTION_WARNING], results.warnings)
+
+        res = [None] * 3
+        with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+            results = run(self.corpus[:3], None, {}, sc, {}, agg, self.state)
+        self.assertListEqual([], results.scores)
+        self.assertListEqual(["mBERT"], results.labels)
+        # not stored in all_keywords since not all keywords were extracted
+        self.assertDictEqual({}, results.all_keywords)
+        self.assertListEqual([CONNECTION_WARNING], results.warnings)
+
     def assertNanEqual(self, table1, table2):
         for list1, list2 in zip(table1, table2):
             for x1, x2 in zip(list1, list2):
@@ -274,6 +301,38 @@ def test_selection_n_best(self):
         output = self.get_output(self.widget.Outputs.words)
         self.assertEqual(5, len(output))
 
+    def test_connection_error(self):
+        self.widget.controlArea.findChildren(QCheckBox)[0].click()  # unselect tfidf
+        self.widget.controlArea.findChildren(QCheckBox)[3].click()  # select mbert
+        res = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
+        with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+            self.send_signal(self.widget.Inputs.corpus, self.corpus)
+            output = self.get_output(self.widget.Outputs.words)
+            self.assertEqual(len(output), 2)
+            np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
+            np.testing.assert_array_equal(output.X, [[7.5], [1]])
+            self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
+            self.assertEqual(
+                CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
+            )
+
+        res = [None] * 3  # all failed
+        with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+            self.send_signal(self.widget.Inputs.corpus, self.corpus)
+            self.assertIsNone(self.get_output(self.widget.Outputs.words))
+            self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
+            self.assertEqual(
+                CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
+            )
+
+        res = [[("keyword1", 10), ("keyword2", 2)], [("keyword1", 5)]]
+        with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
+            self.send_signal(self.widget.Inputs.corpus, self.corpus)
+            output = self.get_output(self.widget.Outputs.words)
+            np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
+            np.testing.assert_array_equal(output.X, [[7.5], [1]])
+            self.assertFalse(self.widget.Warning.extraction_warnings.is_shown())
+
 
 if __name__ == "__main__":
     unittest.main()
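
The change boils down to a small degradation path: the embedder call is wrapped so a connection failure yields `None` for every document, and `run` then drops the `None` entries, records `CONNECTION_WARNING`, and skips the `all_keywords` cache so the missing documents are recomputed on the next run. Below is a minimal, self-contained sketch of that flow; it is plain Python rather than the Orange code, and `ConnectionFailed` and `fetch_keywords` are hypothetical stand-ins for `EmbeddingConnectionError` and `embedd_data`.

```python
from typing import Callable, List, Optional, Tuple

Keywords = List[Tuple[str, float]]


class ConnectionFailed(Exception):
    """Hypothetical stand-in for Orange's EmbeddingConnectionError."""


def keywords_or_none(
    documents: List[str],
    fetch_keywords: Callable[[List[str]], List[Optional[Keywords]]],
) -> List[Optional[Keywords]]:
    # Mirrors the mbert.py hunk: never raise on connection problems,
    # return a None placeholder for every document instead.
    try:
        return fetch_keywords(documents)
    except ConnectionFailed:
        return [None] * len(documents)


def score_documents(
    documents: List[str],
    fetch_keywords: Callable[[List[str]], List[Optional[Keywords]]],
    cache: dict,
    warnings: List[str],
) -> List[Keywords]:
    # Mirrors the run() hunk: drop failed documents, record a warning and
    # skip the cache so a rerun recomputes the missing keywords.
    raw = keywords_or_none(documents, fetch_keywords)
    keywords = [kw for kw in raw if kw is not None]
    if len(keywords) < len(raw):
        warnings.append("Could not extract keywords from some documents.")
    else:
        cache["mBERT"] = keywords
    return keywords


if __name__ == "__main__":
    cache, warns = {}, []

    def flaky(_docs):
        # The second document fails upstream and comes back as None.
        return [[("keyword1", 10.0)], None, [("keyword2", 2.0)]]

    kws = score_documents(["a", "b", "c"], flaky, cache, warns)
    assert kws == [[("keyword1", 10.0)], [("keyword2", 2.0)]]
    assert cache == {} and len(warns) == 1  # nothing cached, one warning emitted
```

Skipping the cache on partial failure is the key design choice: the per-document results that did succeed are already held in the embedding cache, so rerunning the extraction only re-requests the documents that failed.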