diff --git a/src/datatrove/pipeline/filters/language_filter.py b/src/datatrove/pipeline/filters/language_filter.py index b1573582..f43e5153 100644 --- a/src/datatrove/pipeline/filters/language_filter.py +++ b/src/datatrove/pipeline/filters/language_filter.py @@ -17,6 +17,7 @@ def __init__( exclusion_writer: DiskWriter = None, backend: Literal["ft176", "glotlid"] = "ft176", label_only: bool = False, + keep_top_pairs_threshold: float = -1, ): """ filters if the predicted language is not among given language or if the language score is below language @@ -27,6 +28,7 @@ def __init__( language_threshold: language_threshold minimum score to accept a document exclusion_writer: label_only: if True, only the language label is added to the metadata and no documents are removed + keep_top_pairs_threshold: keep a list of all language pairs with at least this score. -1 to disable """ super().__init__(exclusion_writer) self.language_threshold = language_threshold @@ -36,6 +38,7 @@ def __init__( self.backend = backend self.model = FT176LID(languages) if backend == "ft176" else GlotLID(languages) self.label_only = label_only + self.keep_top_pairs_threshold = keep_top_pairs_threshold def filter(self, doc: Document) -> bool: """Args: @@ -51,6 +54,10 @@ def filter(self, doc: Document) -> bool: doc.metadata["language_script"] = script doc.metadata["language"] = lang doc.metadata["language_score"] = lang_score + if self.keep_top_pairs_threshold != -1: + for key, value in lang_pairs.items(): + if value > self.keep_top_pairs_threshold: + doc.metadata[f"top_language_{key}_score"] = value return ( self.label_only or (self.languages and any(score > self.language_threshold for score in lang_pairs.values()))