-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LLM pipeline] Language filter component (#232)
This PR adds the first component for the LLM dataset creation pipeline. The component is a language filter which filters out rows in a provided dataframe that are not matching the provided language. FastText is used for the language detection. Changes - add component - add unit test to test the filter logic inside of the component Note: Did not create a pipeline that uses this component yet. --------- Co-authored-by: NielsRogge <[email protected]> Co-authored-by: Robbe Sneyders <[email protected]>
- Loading branch information
1 parent
b544cd4
commit d06b9e0
Showing
7 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
FROM --platform=linux/amd64 python:3.8-slim | ||
|
||
## System dependencies | ||
RUN apt-get update && \ | ||
apt-get upgrade -y && \ | ||
apt-get install git -y | ||
|
||
# install requirements | ||
COPY requirements.txt / | ||
RUN pip3 install --no-cache-dir -r requirements.txt | ||
|
||
# Set the working directory to the component folder | ||
WORKDIR /component/src | ||
|
||
# Copy over src-files | ||
COPY src/ . | ||
|
||
ENTRYPOINT ["python", "main.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Language filter | ||
|
||
## Description | ||
This component is based on the `TransformComponent` and is used to filter a dataframe based on language. | ||
It allows you to remove rows that do not match the provided language, thus providing a way to focus | ||
on specific languages within your data. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: Filter languages | ||
description: A component that filters text based on the language. | ||
image: ghcr.io/ml6team/filter_language:latest | ||
|
||
consumes: | ||
text: | ||
fields: | ||
data: | ||
type: string | ||
|
||
args: | ||
language: | ||
description: A valid language code or identifier (e.g., "en", "fr", "de"). | ||
type: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
git+https://github.com/ml6team/fondant@main | ||
pyarrow>=7.0 | ||
gcsfs==2023.4.00 | ||
fasttext-wheel==0.9.2 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""A component that filters text based on the language.""" | ||
import logging | ||
|
||
import fasttext | ||
import pandas as pd | ||
|
||
from fondant.component import PandasTransformComponent | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LanguageIdentification: | ||
"""A class for language detection using FastText.""" | ||
|
||
def __init__(self, language, model_path: str = "lid.176.ftz"): | ||
""" | ||
Initializes the LanguageDetect class. | ||
Args: | ||
language (str): language to filter on | ||
model_path (str): The path to the FastText language identification model. | ||
""" | ||
pretrained_lang_model_weight_path = model_path | ||
self.language = language | ||
self.model = fasttext.load_model(pretrained_lang_model_weight_path) | ||
|
||
def predict_lang(self, text: str): | ||
""" | ||
Detects the language of a text sequence. | ||
Args: | ||
text (str): The text for language detection. | ||
Returns: | ||
str: The predicted language label. | ||
""" | ||
predictions = self.model.predict(text, k=1) | ||
return predictions[0][0] | ||
|
||
def is_language(self, row): | ||
"""Predict if text of a row is written in the defined language.""" | ||
return self.language in self.predict_lang(row["text"]) | ||
|
||
|
||
class LanguageFilterComponent(PandasTransformComponent): | ||
"""Component that filter columns based on provided language.""" | ||
|
||
def setup(self, *, language): | ||
"""Setup language filter component. | ||
Args: | ||
language: Only keep text passages which are in the provided language. | ||
""" | ||
self.lang_detector = LanguageIdentification(language) | ||
|
||
|
||
def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Args: | ||
dataframe: Pandas dataframe. | ||
Returns: | ||
Pandas dataframe | ||
""" | ||
mask = dataframe.apply(self.lang_detector.is_language, axis=1) | ||
|
||
return dataframe[mask] | ||
|
||
|
||
if __name__ == "__main__": | ||
component = LanguageFilterComponent.from_args() | ||
component.run() |
54 changes: 54 additions & 0 deletions
54
components/language_filter/tests/language_filter_component_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Unit test for language filter component.""" | ||
import pandas as pd | ||
|
||
from components.language_filter.src.main import LanguageFilterComponent | ||
from fondant.component_spec import ComponentSpec | ||
|
||
|
||
def test_run_component_test(): | ||
"""Test language filter component.""" | ||
# Given: Dataframe with text in different languages | ||
data = [{"text": "Das hier ist ein Satz in deutscher Sprache"}, | ||
{"text": "This is a sentence in English"}, | ||
{"text": "Dit is een zin in het Nederlands"}] | ||
dataframe = pd.DataFrame(data) | ||
|
||
# When: The language filter component proceed the dataframe | ||
# and filter out all entries which are not written in german | ||
spec = ComponentSpec.from_file("../fondant_component.yaml") | ||
|
||
component = LanguageFilterComponent(spec, input_manifest_path="./dummy_input_manifest.json", | ||
output_manifest_path="./dummy_input_manifest.json", | ||
metadata={}, | ||
user_arguments={"language": "de"}, | ||
) | ||
component.setup(language="de") | ||
dataframe = component.transform(dataframe=dataframe) | ||
|
||
# Then: dataframe only contains one german row | ||
assert len(dataframe) == 1 | ||
assert dataframe.loc[0]["text"] == "Das hier ist ein Satz in deutscher Sprache" | ||
|
||
|
||
def test_run_component_test_filter_out_all(): | ||
"""Test language filter component.""" | ||
# Given: Dataframe with text in different languages | ||
data = [{"text": "Das hier ist ein Satz in deutscher Sprache"}, | ||
{"text": "This is a sentence in English"}, | ||
{"text": "Dit is een zin in het Nederlands"}] | ||
dataframe = pd.DataFrame(data) | ||
|
||
# When: The language filter component proceed the dataframe | ||
# and filter out all entries which are not written in french | ||
spec = ComponentSpec.from_file("../fondant_component.yaml") | ||
|
||
component = LanguageFilterComponent(spec, input_manifest_path="./dummy_input_manifest.json", | ||
output_manifest_path="./dummy_input_manifest.json", | ||
metadata={}, | ||
user_arguments={"language": "fr"}, | ||
) | ||
component.setup() | ||
dataframe = component.transform(dataframe=dataframe) | ||
|
||
# Then: dataframe should contain no rows anymore | ||
assert len(dataframe) == 0 |