From c5cebb55358312f7a4d77479f6001e2da0c79dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haukur=20P=C3=A1ll?= Date: Wed, 25 Oct 2023 13:37:48 +0000 Subject: [PATCH] Fixing eval.py to support GreynirCorrect v4.0.0 --- eval/eval.py | 280 +++++++++++++++------------------------------------ 1 file changed, 82 insertions(+), 198 deletions(-) diff --git a/eval/eval.py b/eval/eval.py index 52e1965..9c512a7 100755 --- a/eval/eval.py +++ b/eval/eval.py @@ -85,29 +85,38 @@ $ python eval.py -a """ - from typing import ( TYPE_CHECKING, + Any, + Counter, + DefaultDict, Dict, + Iterable, List, Optional, Set, - Union, Tuple, - Iterable, + Union, cast, - Any, - DefaultDict, - Counter, ) -import os -from collections import defaultdict -from datetime import datetime +import argparse import glob +import os import random -import argparse import xml.etree.ElementTree as ET +from collections import defaultdict +from datetime import datetime + +from reynir_correct import ( + Annotation, + CorrectedSentence, + CorrectionPipeline, + GreynirCorrect, + GreynirCorrectAPI, + Settings, +) +from tokenizer import TOK, Tok, detokenize if TYPE_CHECKING: # For some reason, types seem to be missing from the multiprocessing module @@ -116,17 +125,6 @@ else: import multiprocessing -from reynir import _Sentence -from tokenizer import detokenize, Tok, TOK - -from reynir_correct.annotation import Annotation -from reynir_correct.checker import ( - GreynirCorrect, - Settings, - AnnotatedSentence, - check as gc_check, -) - # Disable Pylint warnings arising from Pylint not understanding the typing module # pylint: disable=no-member @@ -144,9 +142,7 @@ CategoryStatsDict = DefaultDict[str, SentenceStatsDict] # This tuple should agree with the parameters of the add_sentence() function -StatsTuple = Tuple[ - str, int, bool, bool, int, int, int, int, int, int, int, int, int, int, int, int -] +StatsTuple = Tuple[str, int, bool, bool, int, int, int, int, int, int, int, int, int, int, int, int] # Counter of tp, tn, right_corr, wrong_corr, right_span, wrong_span TypeFreqs = Counter[str] @@ -159,7 +155,8 @@ settings = Settings() settings.read(os.path.join("config", "GreynirCorrect.conf")) -rc = GreynirCorrect(settings) +gc = GreynirCorrect(settings, pipeline=CorrectionPipeline("", settings=settings)) +rc = GreynirCorrectAPI(gc=gc) # Create a lock to ensure that only one process outputs at a time OUTPUT_LOCK = multiprocessing.Lock() @@ -314,9 +311,7 @@ # Three levels: Supercategories, subcategories and error codes # supercategory: {subcategory : [error code]} -SUPERCATEGORIES: DefaultDict[str, DefaultDict[str, List[str]]] = defaultdict( - lambda: defaultdict(list) -) +SUPERCATEGORIES: DefaultDict[str, DefaultDict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) GCtoIEC = { "A001": ["abbreviation-period"], @@ -595,8 +590,7 @@ parser = argparse.ArgumentParser( description=( - "This program evaluates the spelling and grammar checking performance " - "of GreynirCorrect on iceErrorCorpus" + "This program evaluates the spelling and grammar checking performance " "of GreynirCorrect on iceErrorCorpus" ) ) @@ -674,9 +668,7 @@ help="Create an analysis report for token results", ) -parser.add_argument( - "-f", "--catfile", type=str, default="iceErrorCorpus/errorCodes.tsv" -) +parser.add_argument("-f", "--catfile", type=str, default="iceErrorCorpus/errorCodes.tsv") # This boolean global is set to True for quiet output, # which is the default when processing the test corpus @@ -710,9 +702,7 @@ def __init__(self) -> None: self._files: Dict[str, int] = 
defaultdict(int) # We employ a trick to make the defaultdicts picklable between processes: # instead of the usual lambda: defaultdict(int), use defaultdict(int).copy - self._sentences: CategoryStatsDict = CategoryStatsDict( - SentenceStatsDict(int).copy - ) + self._sentences: CategoryStatsDict = CategoryStatsDict(SentenceStatsDict(int).copy) self._errtypes: ErrTypeStatsDict = ErrTypeStatsDict(Counter) self._true_positives: DefaultDict[str, int] = defaultdict(int) self._false_negatives: DefaultDict[str, int] = defaultdict(int) @@ -823,9 +813,7 @@ def output(self, cores: int) -> None: if SINGLE: bprint(f"") - num_sentences: int = sum( - cast(int, d["count"]) for d in self._sentences.values() - ) + num_sentences: int = sum(cast(int, d["count"]) for d in self._sentences.values()) def output_duration() -> None: # type: ignore """Calculate the duration of the processing""" @@ -859,14 +847,10 @@ def perc(n: int, whole: int) -> str: return "N/A" return f"{100.0*n/whole:3.2f}" - def write_basic_value( - val: int, bv: str, whole: int, errwhole: Optional[int] = None - ) -> None: + def write_basic_value(val: int, bv: str, whole: int, errwhole: Optional[int] = None) -> None: """Write basic values for sentences and their freqs to stdout""" if errwhole: - bprint( - f"\n{NAMES[bv]+':':<20} {val:6} {perc(val, whole):>6}% / {perc(val, errwhole):>6}%" - ) + bprint(f"\n{NAMES[bv]+':':<20} {val:6} {perc(val, whole):>6}% / {perc(val, errwhole):>6}%") else: bprint(f"\n{NAMES[bv]+':':<20} {val:6} {perc(val, whole):>6}%") for c in GENRES: @@ -939,9 +923,7 @@ def calc_PRF( else: bprint(f" {c:<13}: N/A") - def calc_recall( - right: int, wrong: int, rights: str, wrongs: str, recs: str - ) -> None: + def calc_recall(right: int, wrong: int, rights: str, wrongs: str, recs: str) -> None: """Calculate precision for binary classification""" # Recall if right + wrong == 0: @@ -1000,9 +982,7 @@ def calc_error_category_metrics(cat: str) -> CatResultDict: precision = catdict["precision"] = tp / (tp + fp) # F0.5 score if recall + precision > 0.0: - catdict["f05score"] = ( - 1.25 * (precision * recall) / (0.25 * precision + recall) - ) + catdict["f05score"] = 1.25 * (precision * recall) / (0.25 * precision + recall) else: catdict["f05score"] = NO_RESULTS # Error correction metrics @@ -1014,26 +994,20 @@ def calc_error_category_metrics(cat: str) -> CatResultDict: cprecision = catdict["cprecision"] = ctp / (ctp + cfp) # F0.5 score if crecall + cprecision > 0.0: - catdict["cf05score"] = ( - 1.25 * (cprecision * crecall) / (0.25 * cprecision + crecall) - ) + catdict["cf05score"] = 1.25 * (cprecision * crecall) / (0.25 * cprecision + crecall) else: catdict["cf05score"] = NO_RESULTS # Correction recall (not used) right_corr = cast(int, catdict.get("right_corr", 0)) if right_corr > 0: - catdict["corr_rec"] = right_corr / ( - right_corr + cast(int, catdict.get("wrong_corr", 0)) - ) + catdict["corr_rec"] = right_corr / (right_corr + cast(int, catdict.get("wrong_corr", 0))) else: catdict["corr_rec"] = -1.0 # Span recall right_span = cast(int, catdict.get("right_span", 0)) if right_span > 0: - catdict["span_rec"] = right_span / ( - right_span + cast(int, catdict.get("wrong_span", 0)) - ) + catdict["span_rec"] = right_span / (right_span + cast(int, catdict.get("wrong_span", 0))) else: catdict["span_rec"] = NO_RESULTS return catdict @@ -1043,18 +1017,10 @@ def output_sentence_scores() -> None: # type: ignore # Total number of true negatives found bprint(f"\nResults for error detection for whole sentences") - true_positives: int = sum( - 
cast(int, d["true_positives"]) for d in self._sentences.values() - ) - true_negatives: int = sum( - cast(int, d["true_negatives"]) for d in self._sentences.values() - ) - false_positives: int = sum( - cast(int, d["false_positives"]) for d in self._sentences.values() - ) - false_negatives: int = sum( - cast(int, d["false_negatives"]) for d in self._sentences.values() - ) + true_positives: int = sum(cast(int, d["true_positives"]) for d in self._sentences.values()) + true_negatives: int = sum(cast(int, d["true_negatives"]) for d in self._sentences.values()) + false_positives: int = sum(cast(int, d["false_positives"]) for d in self._sentences.values()) + false_negatives: int = sum(cast(int, d["false_negatives"]) for d in self._sentences.values()) write_basic_value(true_positives, "true_positives", num_sentences) write_basic_value(true_negatives, "true_negatives", num_sentences) @@ -1067,12 +1033,7 @@ def output_sentence_scores() -> None: # type: ignore if num_sentences == 0: result = "N/A" else: - result = ( - perc(true_results, num_sentences) - + "%/" - + perc(false_results, num_sentences) - + "%" - ) + result = perc(true_results, num_sentences) + "%/" + perc(false_results, num_sentences) + "%" bprint(f"\nTrue/false split: {result:>16}") for c in GENRES: d = self._sentences[c] @@ -1126,9 +1087,7 @@ def output_token_scores() -> None: # type: ignore bprint(f"\n\nResults for error detection within sentences") - num_tokens = sum( - cast(int, d["num_tokens"]) for d in self._sentences.values() - ) + num_tokens = sum(cast(int, d["num_tokens"]) for d in self._sentences.values()) bprint(f"\nTokens processed: {num_tokens:6}") for c in GENRES: bprint(f" {c:<13}: {self._sentences[c]['num_tokens']:6}") @@ -1161,35 +1120,23 @@ def output_token_scores() -> None: # type: ignore # Loose: Of all errors the tool correctly finds, how many get the right correction? # Can only calculate recall. bprint(f"\nResults for error correction") - right_corr = sum( - cast(int, d["right_corr"]) for d in self._sentences.values() - ) - wrong_corr = sum( - cast(int, d["wrong_corr"]) for d in self._sentences.values() - ) + right_corr = sum(cast(int, d["right_corr"]) for d in self._sentences.values()) + wrong_corr = sum(cast(int, d["wrong_corr"]) for d in self._sentences.values()) write_basic_value(right_corr, "right_corr", num_tokens, tp) write_basic_value(wrong_corr, "wrong_corr", num_tokens, tp) - calc_recall( - right_corr, wrong_corr, "right_corr", "wrong_corr", "correctrecall" - ) + calc_recall(right_corr, wrong_corr, "right_corr", "wrong_corr", "correctrecall") # Stiff: Of all errors in error corpora, how many get the right span? # Loose: Of all errors the tool correctly finds, how many get the right span? # Can only calculate recall. 
bprint(f"\nResults for error span") - right_span = sum( - cast(int, d["right_span"]) for d in self._sentences.values() - ) - wrong_span = sum( - cast(int, d["wrong_span"]) for d in self._sentences.values() - ) + right_span = sum(cast(int, d["right_span"]) for d in self._sentences.values()) + wrong_span = sum(cast(int, d["wrong_span"]) for d in self._sentences.values()) write_basic_value(right_span, "right_span", num_tokens, tp) write_basic_value(wrong_span, "wrong_span", num_tokens, tp) - calc_recall( - right_span, wrong_span, "right_span", "wrong_span", "spanrecall" - ) + calc_recall(right_span, wrong_span, "right_span", "wrong_span", "spanrecall") def output_error_cat_scores() -> None: """Calculate and write scores for each error category to stdout""" @@ -1234,10 +1181,7 @@ def output_error_cat_scores() -> None: cast(float, rk.get("f05score", 0.0)) * 100.0, ) ) - if ( - rk.get("corr_rec", "N/A") == "N/A" - or rk.get("span_rec", "N/A") == "N/A" - ): + if rk.get("corr_rec", "N/A") == "N/A" or rk.get("span_rec", "N/A") == "N/A": bprint("\tCorr, span: N/A, N/A") else: bprint( @@ -1263,9 +1207,7 @@ def output_supercategory_scores(): in SUPERCATEGORIES, each subcategory, and error code""" bprint("Supercategory: frequency, F-score") bprint("\tSubcategory: frequency, F-score") - bprint( - "\t\tError code: frequency, (recall, precision, F-score), (tp, fn, fp)| correct recall" - ) + bprint("\t\tError code: frequency, (recall, precision, F-score), (tp, fn, fp)| correct recall") totalfreq = 0 totalf = 0.0 for supercat in SUPERCATEGORIES: @@ -1287,40 +1229,24 @@ def output_supercategory_scores(): freq = cast(int, et["freq"]) fscore = cast(float, et["f05score"]) # codework - subblob = ( - subblob - + "\t\t{} {} ({:3.2f}, {:3.2f}, {:3.2f}) ({},{},{})| {}\n".format( - code, - freq, - cast(float, et["recall"]) * 100.0 - if "recall" in et - else 0.0, - cast(float, et["precision"]) * 100.0 - if "precision" in et - else 0.0, - fscore * 100.0, - cast(int, et["tp"]) if "tp" in et else 0, - cast(int, et["fn"]) if "fn" in et else 0, - cast(int, et["fp"]) if "fp" in et else 0, - cast(float, et["corr_rec"]) - if "corr_rec" in et - else 0.0, - ) + subblob = subblob + "\t\t{} {} ({:3.2f}, {:3.2f}, {:3.2f}) ({},{},{})| {}\n".format( + code, + freq, + cast(float, et["recall"]) * 100.0 if "recall" in et else 0.0, + cast(float, et["precision"]) * 100.0 if "precision" in et else 0.0, + fscore * 100.0, + cast(int, et["tp"]) if "tp" in et else 0, + cast(int, et["fn"]) if "fn" in et else 0, + cast(int, et["fp"]) if "fp" in et else 0, + cast(float, et["corr_rec"]) if "corr_rec" in et else 0.0, ) # subwork subfreq += freq subf += fscore * freq * 100.0 if subfreq != 0: - subblob = ( - "\t{} {} {}\n".format( - subcat.capitalize(), subfreq, subf / subfreq - ) - + subblob - ) + subblob = "\t{} {} {}\n".format(subcat.capitalize(), subfreq, subf / subfreq) + subblob else: - subblob = ( - "\t{} 0 N/A\n".format(subcat.capitalize()) + subblob - ) + subblob = "\t{} 0 N/A\n".format(subcat.capitalize()) + subblob # superwork # freq, f05 superblob += subblob @@ -1328,15 +1254,10 @@ def output_supercategory_scores(): superf += subf # TODO is this correct? 
if superfreq != 0: superblob = ( - "\n{} {} {}\n".format( - supercat.capitalize(), superfreq, superf / superfreq - ) - + superblob + "\n{} {} {}\n".format(supercat.capitalize(), superfreq, superf / superfreq) + superblob ) else: - superblob = ( - "\n{} 0 N/A\n".format(supercat.capitalize()) + superblob - ) + superblob = "\n{} 0 N/A\n".format(supercat.capitalize()) + superblob totalfreq += superfreq totalf += superf # TODO is this correct? bprint("".join(superblob)) @@ -1412,10 +1333,7 @@ def output_all_scores(): if ("recall" in et and float(et["recall"]) > 0.0) else NO_RESULTS, # Or "N/A", but that messes with the f-string formatting cast(float, et["precision"]) * 100.0 - if ( - "precision" in et - and float(et["precision"]) > 0.0 - ) + if ("precision" in et and float(et["precision"]) > 0.0) else NO_RESULTS, fscore * 100.0 if fscore > 0.0 else NO_RESULTS, cast(int, et["ctp"]) if "ctp" in et else 0, @@ -1425,10 +1343,7 @@ def output_all_scores(): if ("crecall" in et and float(et["crecall"]) > 0.0) else NO_RESULTS, cast(float, et["cprecision"]) * 100.0 - if ( - "cprecision" in et - and float(et["cprecision"]) > 0.0 - ) + if ("cprecision" in et and float(et["cprecision"]) > 0.0) else NO_RESULTS, cfscore * 100.0 if cfscore > 0.0 else NO_RESULTS, ) @@ -1459,9 +1374,7 @@ def output_all_scores(): ) subcprecision += ( cast(float, et["cprecision"]) * freq * 100.0 - if ( - "cprecision" in et and float(et["cprecision"]) > 0.0 - ) + if ("cprecision" in et and float(et["cprecision"]) > 0.0) else 0.0 ) subcf += cfscore * freq * 100.0 if cfscore > 0.0 else 0.0 @@ -1613,10 +1526,7 @@ def output_all_scores(): def correct_spaces(tokens: List[Tuple[str, str]]) -> str: """Returns a string with a reasonably correct concatenation of the tokens, where each token is a (tag, text) tuple.""" - return detokenize( - Tok(TOK.PUNCTUATION if tag == "c" else TOK.WORD, txt, None) - for tag, txt in tokens - ) + return detokenize(Tok(TOK.PUNCTUATION if tag == "c" else TOK.WORD, txt, None) for tag, txt in tokens) # Accumulate standard output in a buffer, for writing in one fell @@ -1630,7 +1540,6 @@ def bprint(s: str): def process(fpath_and_category: Tuple[str, str]) -> Dict[str, Any]: - """Process a single error corpus file in TEI XML format. 
     This function is called within a multiprocessing pool
     and therefore usually executes in a child process, separate
@@ -1660,7 +1569,6 @@ def process(fpath_and_category: Tuple[str, str]) -> Dict[str, Any]:
     errtypefreqs: ErrTypeStatsDict = ErrTypeStatsDict(TypeFreqs().copy)
 
     try:
-
         if not QUIET:
             # Output a file header
             bprint("-" * 64)
@@ -1711,9 +1619,7 @@ def process(fpath_and_category: Tuple[str, str]) -> Dict[str, Any]:
             if el_orig is not None:
                 # We have 0 or more original tokens embedded
                 # within the revision tag
-                orig_tokens = [
-                    (subel.tag[nl:], element_text(subel)) for subel in el_orig
-                ]
+                orig_tokens = [(subel.tag[nl:], element_text(subel)) for subel in el_orig]
                 tokens.extend(orig_tokens)
                 original = " ".join(t[1] for t in orig_tokens).strip()
             # Calculate the index of the ending token within the span
@@ -1755,9 +1661,7 @@ def process(fpath_and_category: Tuple[str, str]) -> Dict[str, Any]:
                 else:
                     if QUIET:
                         bprint(f"In file {fpath}:")
-                    bprint(
-                        f"\n{index}: *** 'depId' attribute missing for dependency ***"
-                    )
+                    bprint(f"\n{index}: *** 'depId' attribute missing for dependency ***")
             if SINGLE and xtype == SINGLE:
                 check = True
             else:
@@ -1782,20 +1686,13 @@ def process(fpath_and_category: Tuple[str, str]) -> Dict[str, Any]:
                 # Nothing to do: drop this and go to the next sentence
                 continue
             # print(text)
-            options = {}
-            options["annotate_unparsed_sentences"] = True  # True is default
-            options["suppress_suggestions"] = False  # False is default
-            options["ignore_rules"] = set(
-                [
-                    "",
-                ]
-            )  # Pass it to GreynirCorrect
-            pg = [list(p) for p in gc_check(text, rc=rc, **options)]
-            s: Optional[_Sentence] = None
-            if len(pg) >= 1 and len(pg[0]) >= 1:
-                s = pg[0][0]
-            if len(pg) > 1 or (len(pg) == 1 and len(pg[0]) > 1):
+            result = rc.correct(text=text, suppress_suggestions=False, ignore_rules=set())
+            pg = result.sentences
+            s: Optional[CorrectedSentence] = None
+            if len(pg) >= 1:
+                s = pg[0]
+            if len(pg) > 1:
                 # if QUIET:
                 #     bprint(f"In file {fpath}:")
                 # bprint(
@@ -1845,9 +1742,7 @@ def sentence_results(
             if unparsable:
                 ups[xtype] += 1
                 if not QUIET:
-                    bprint(
-                        f"<<< {err['start']:03}-{err['end']:03}: {asterisk}{xtype}"
-                    )
+                    bprint(f"<<< {err['start']:03}-{err['end']:03}: {asterisk}{xtype}")
         if not QUIET:
             # Output true/false positive/negative result
             if ice_error and gc_error:
@@ -1867,8 +1762,8 @@ def sentence_results(
         return gc_error, ice_error
 
     assert s is not None
-    assert isinstance(s, AnnotatedSentence)
-    gc_error, ice_error = sentence_results(s.annotations, errors)
+    assert isinstance(s, CorrectedSentence)
+    gc_error, ice_error = sentence_results(s.annotations or [], errors)
 
     def token_results(
         hyp_annotations: Iterable[Annotation],
@@ -1979,20 +1874,14 @@ def token_results(
                     # Multiple tags for same error: Skip rest
                     if xspan == xspanlast:
                         if ANALYSIS:
-                            analysisblob.append(
-                                "\t Same span, skip: {}".format(
-                                    cast(str, xtok["xtype"])
-                                )
-                            )
+                            analysisblob.append("\t Same span, skip: {}".format(cast(str, xtok["xtype"])))
                         xtok = None
                         xtok = next(x)
                         continue
                     if ytok.code in GCSKIPCODES:
                         # Skip these errors, shouldn't be compared.
                         if ANALYSIS:
-                            analysisblob.append(
-                                "\t Skip: {}".format(ytok.code)
-                            )
+                            analysisblob.append("\t Skip: {}".format(ytok.code))
                         ytok = None
                         ytok = next(y)
                         continue
@@ -2070,9 +1959,7 @@ def token_results(
                 if ytok.code in GCSKIPCODES:
                     # Skip these errors, shouldn't be a part of the results.
if ANALYSIS: - analysisblob.append( - "\t Skip: {}".format(ytok.code) - ) + analysisblob.append("\t Skip: {}".format(ytok.code)) ytok = next(y, None) continue fp += 1 @@ -2098,9 +1985,7 @@ def token_results( if xspan == xspanlast: # Multiple tags for same error: Skip rest if ANALYSIS: - analysisblob.append( - "\t Same span, skip: {}".format(xtype) - ) + analysisblob.append("\t Same span, skip: {}".format(xtype)) xtok = None xtok = next(x, None) else: @@ -2126,7 +2011,6 @@ def token_results( wrong_span, ) - assert isinstance(s, AnnotatedSentence) ( tp, fp, @@ -2138,7 +2022,7 @@ def token_results( cfn, right_span, wrong_span, - ) = token_results(s.annotations, errors) + ) = token_results(s.annotations or [], errors) tn = len(tokens) - tp - fp - fn ctn = len(tokens) - ctp - cfp - cfn # Collect statistics into the stats list, to be returned
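
Reviewer note (appended after the diff, not part of it): the sketch below
summarizes the GreynirCorrect v4.0.0 call pattern this patch migrates eval.py
to. It mirrors the construction in the hunks above: a GreynirCorrect engine
built around an explicit CorrectionPipeline, wrapped in a GreynirCorrectAPI
whose correct() method replaces the old module-level check() helper, with the
former options dict passed as keyword arguments. The sample sentence, variable
names, and the printed Annotation fields (code, start, end, text) are
illustrative assumptions; everything else follows the diff.

    import os

    from reynir_correct import (
        CorrectionPipeline,
        GreynirCorrect,
        GreynirCorrectAPI,
        Settings,
    )

    # Same setup as the patched eval.py: read the config, build the engine,
    # then wrap it in the v4 API object.
    settings = Settings()
    settings.read(os.path.join("config", "GreynirCorrect.conf"))
    gc = GreynirCorrect(settings, pipeline=CorrectionPipeline("", settings=settings))
    api = GreynirCorrectAPI(gc=gc)

    # correct() replaces the old gc_check(); options are keyword arguments now.
    # "Thetta er prufa." below is an arbitrary Icelandic sample sentence.
    result = api.correct(text="Þetta er prufa.", suppress_suggestions=False, ignore_rules=set())

    # result.sentences is a flat list of CorrectedSentence objects, which is
    # why the patch replaces the old pg[0][0] paragraph indexing with pg[0].
    # The `or []` guard mirrors the patch: annotations may be None or empty
    # (e.g. for unparsed sentences).
    for sent in result.sentences:
        for ann in sent.annotations or []:
            print(ann.code, ann.start, ann.end, ann.text)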