Skip to content

Commit

Permalink
option to keep more language scores
Browse files Browse the repository at this point in the history
  • Loading branch information
guipenedo committed Jul 5, 2024
1 parent 7ba873f commit 061d4db
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/datatrove/pipeline/filters/language_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
exclusion_writer: DiskWriter = None,
backend: Literal["ft176", "glotlid"] = "ft176",
label_only: bool = False,
keep_top_pairs_threshold: float = -1,
):
"""
filters if the predicted language is not among given language or if the language score is below language
Expand All @@ -27,6 +28,7 @@ def __init__(
language_threshold: language_threshold minimum score to accept a document
exclusion_writer:
label_only: if True, only the language label is added to the metadata and no documents are removed
keep_top_pairs_threshold: keep a list of all language pairs with at least this score. -1 to disable
"""
super().__init__(exclusion_writer)
self.language_threshold = language_threshold
Expand All @@ -36,6 +38,7 @@ def __init__(
self.backend = backend
self.model = FT176LID(languages) if backend == "ft176" else GlotLID(languages)
self.label_only = label_only
self.keep_top_pairs_threshold = keep_top_pairs_threshold

def filter(self, doc: Document) -> bool:
"""Args:
Expand All @@ -51,6 +54,10 @@ def filter(self, doc: Document) -> bool:
doc.metadata["language_script"] = script
doc.metadata["language"] = lang
doc.metadata["language_score"] = lang_score
if self.keep_top_pairs_threshold != -1:
for key, value in lang_pairs.items():
if value > self.keep_top_pairs_threshold:
doc.metadata[f"top_language_{key}_score"] = value
return (
self.label_only
or (self.languages and any(score > self.language_threshold for score in lang_pairs.values()))
Expand Down

0 comments on commit 061d4db

Please sign in to comment.