From 8249b98726d4d61d8100e4197e7be2f22208ae20 Mon Sep 17 00:00:00 2001
From: Yilun Huang
Date: Tue, 21 Nov 2023 10:24:01 +0800
Subject: [PATCH] Add availability checking for OPs to allow incomplete
 dependency installation (#82)

* Split requirements into more categories
* Add availability checking, performed when each OP is imported, for OPs
  that rely on large or low-platform-compatibility third-party libraries
* Remove duplicate dependencies

---
 README.md                                     | 34 ++++----
 README_ZH.md                                  | 30 ++++---
 data_juicer/core/ray_executor.py              |  7 +-
 .../document_minhash_deduplicator.py          |  9 ++-
 .../document_simhash_deduplicator.py          | 16 ++--
 .../ops/deduplicator/image_deduplicator.py    | 25 +++---
 data_juicer/ops/filter/alphanumeric_filter.py |  6 ++
 .../ops/filter/clip_similarity_filter.py      | 16 ++--
 .../ops/filter/flagged_words_filter.py        | 10 ++-
 .../ops/filter/language_id_score_filter.py    |  8 +-
 data_juicer/ops/filter/perplexity_filter.py   | 11 ++-
 data_juicer/ops/filter/stopwords_filter.py    | 10 ++-
 data_juicer/ops/filter/token_num_filter.py    |  8 +-
 data_juicer/ops/filter/word_num_filter.py     | 10 ++-
 .../ops/filter/word_repetition_filter.py      | 10 ++-
 data_juicer/ops/load.py                       |  7 ++
 .../ops/mapper/chinese_convert_mapper.py      | 10 ++-
 data_juicer/ops/mapper/clean_html_mapper.py   |  9 ++-
 data_juicer/ops/mapper/fix_unicode_mapper.py  |  9 ++-
 data_juicer/ops/mapper/nlpaug_en_mapper.py    | 16 ++--
 data_juicer/ops/mapper/nlpcda_zh_mapper.py    |  9 ++-
 ..._words_with_incorrect_substrings_mapper.py |  8 +-
 .../ops/mapper/sentence_split_mapper.py       |  8 +-
 data_juicer/utils/availability_utils.py       | 79 +++++++++++++++++++
 environments/dist_requires.txt                |  1 +
 environments/minimal_requires.txt             |  6 +-
 environments/science_requires.txt             |  4 -
 setup.py                                      |  7 +-
 28 files changed, 296 insertions(+), 87 deletions(-)
 create mode 100644 data_juicer/utils/availability_utils.py
 create mode 100644 environments/dist_requires.txt

diff --git a/README.md b/README.md
index 18b83b0cc..3374e5ea0 100644
--- a/README.md
+++ b/README.md
@@ -105,40 +105,46 @@ Table of Contents
 
 ### From Source
 
-- Run the following commands to install the latest `data_juicer` version in
+- Run the following commands to install the latest basic `data_juicer` version in
   editable mode:
 
 ```shell
 cd <path_to_data_juicer>
-pip install -v -e .[all]
+pip install -v -e .
 ```
 
-- Or install optional dependencies:
+- Some OPs rely on third-party libraries that are too large or poorly
+  compatible across platforms. You can install these optional dependencies as needed:
+
 ```shell
 cd <path_to_data_juicer>
-pip install -v -e .  # install a minimal dependencies
+pip install -v -e .  # install minimal dependencies, which support the basic functions
 pip install -v -e .[tools]  # install a subset of tools dependencies
 ```
 
 The dependency options are listed below:
 
-| Tag      | Description                                                             |
-|----------|-------------------------------------------------------------------------|
-| .        | Install minimal dependencies for basic Data-Juicer.                      |
-| .[all]   | Install all optional dependencies (all of the following)                 |
-| .[dev]   | Install dependencies for developing the package as contributors          |
-| .[tools] | Install dependencies for dedicated tools, such as quality classifiers.   |
+| Tag              | Description                                                                                   |
+|------------------|-----------------------------------------------------------------------------------------------|
+| `.` or `.[mini]` | Install minimal dependencies for basic Data-Juicer.                                            |
+| `.[all]`         | Install all optional dependencies (including minimal dependencies and all of the following).   |
+| `.[sci]`         | Install all dependencies for all OPs.                                                          |
+| `.[dist]`        | Install dependencies for distributed data processing. (Experimental)                           |
+| `.[dev]`         | Install dependencies for developing the package as contributors.                               |
+| `.[tools]`       | Install dependencies for dedicated tools, such as quality classifiers.                         |
 
 ### Using pip
 
-- Run the following command to install the latest `data_juicer` using `pip`:
+- Run the following command to install the latest released `data_juicer` using `pip`:
 
 ```shell
 pip install py-data-juicer
 ```
 
-- **Note**: only the basic APIs in `data_juicer` and two basic tools
-  (data [processing](#data-processing) and [analysis](#data-analysis)) are available in this way. If you want customizable
-  and complete functions, we recommend you install `data_juicer` [from source](#from-source).
+- **Note**:
+  - Only the basic APIs in `data_juicer` and two basic tools
+    (data [processing](#data-processing) and [analysis](#data-analysis)) are available in this way. If you want customizable
+    and complete functionality, we recommend installing `data_juicer` [from source](#from-source).
+  - The release versions on PyPI lag behind the latest source version.
+    If you want to follow the latest features of `data_juicer`, we recommend installing [from source](#from-source).
 
 ### Using Docker

diff --git a/README_ZH.md b/README_ZH.md
index a7a78b455..9d74b5b09 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -93,40 +93,44 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM
 
 ### 从源码安装
 
-* 运行以下命令以安装 `data_juicer` 可编辑模式的最新版本
+* 运行以下命令以安装 `data_juicer` 可编辑模式的最新基础版本
 
 ```shell
 cd <path_to_data_juicer>
-pip install -v -e .[all]
+pip install -v -e .
 ```
 
-* 或是安装可选的依赖项:
+* 部分算子功能依赖于较大的或者平台兼容性不是很好的第三方库,因此用户可按需额外安装可选的依赖项:
+
 ```shell
 cd <path_to_data_juicer>
-pip install -v -e .  # 安装最小依赖
+pip install -v -e .  # 安装最小依赖,支持基础功能
 pip install -v -e .[tools] # 安装部分工具库的依赖
 ```
 
 依赖选项如下表所示:
 
-| 标签 | 描述 |
-|----------|----------------------------------------------|
-| .        | 安装支持 Data-Juicer 基础功能的最小依赖项 |
-| .[all]   | 安装所有可选依赖项(即下面所有依赖项) |
-| .[dev]   | 安装作为贡献者开发 Data-Juicer 所需的依赖项 |
-| .[tools] | 安装专用工具库(如质量分类器)所需的依赖项 |
+| 标签 | 描述 |
+|--------------|------------------------------|
+| `.` 或者 `.[mini]` | 安装支持 Data-Juicer 基础功能的最小依赖项 |
+| `.[all]` | 安装所有可选依赖项(包括最小依赖项以及下面所有依赖项) |
+| `.[sci]` | 安装所有算子的全量依赖 |
+| `.[dist]` | 安装以分布式方式进行数据处理的依赖(实验性功能) |
+| `.[dev]` | 安装作为贡献者开发 Data-Juicer 所需的依赖项 |
+| `.[tools]` | 安装专用工具库(如质量分类器)所需的依赖项 |
 
 ### 使用 pip 安装
 
-* 运行以下命令用 `pip` 安装 `data_juicer` 的最新版本:
+* 运行以下命令用 `pip` 安装 `data_juicer` 的最新发布版本:
 
 ```shell
 pip install py-data-juicer
 ```
 
-* **注意**:使用这种方法安装时,只有`data_juicer`中的基础的 API 和2个基础工具
-  (数据[处理](数据处理)与[分析](数据分析))可以使用。如需更定制化地使用完整功能,建议[从源码进行安装](#从源码安装)。
+* **注意**:
+  * 使用这种方法安装时,只有`data_juicer`中的基础的 API 和2个基础工具
+    (数据[处理](#数据处理)与[分析](#数据分析))可以使用。如需更定制化地使用完整功能,建议[从源码进行安装](#从源码安装)。
+  * PyPI 的发布版本较源码的最新版本有一定的滞后性,如需要随时跟进 `data_juicer` 的最新功能支持,建议[从源码进行安装](#从源码安装)。
 
 ### 使用 Docker 安装

diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py
index 513e16151..e1df77a04 100644
--- a/data_juicer/core/ray_executor.py
+++ b/data_juicer/core/ray_executor.py
@@ -1,11 +1,14 @@
-import ray
-import ray.data as rd
 from loguru import logger
 
 from data_juicer.config import init_configs
 from data_juicer.ops import Filter, Mapper, load_ops
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields
 
+with AvailabilityChecking(['ray'], requires_type='dist'):
+    import ray
+    import ray.data as rd
+
 
 class RayExecutor:
     """

diff --git a/data_juicer/ops/deduplicator/document_minhash_deduplicator.py b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py
index df30cc4e7..b420a9956 100644
--- a/data_juicer/ops/deduplicator/document_minhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py
@@ -10,14 +10,19 @@
 import regex
 from jsonargparse.typing import ClosedUnitInterval, PositiveInt
 from loguru import logger
-from scipy.integrate import quad as integrate
 from tqdm import tqdm
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import HashKeys
 
 from ..base_op import OPERATORS, Deduplicator
 from ..common.helper_func import UnionFind, split_on_whitespace
 
+OP_NAME = 'document_minhash_deduplicator'
+
+with AvailabilityChecking(['scipy'], OP_NAME):
+    from scipy.integrate import quad as integrate
+
 MERSENNE_PRIME = np.uint64((1 << 61) - 1)
 MAX_HASH = np.uint64((1 << 32) - 1)
 
@@ -89,7 +94,7 @@ def proba(s):
         return opt
 
 
-@OPERATORS.register_module('document_minhash_deduplicator')
+@OPERATORS.register_module(OP_NAME)
 class DocumentMinhashDeduplicator(Deduplicator):
     """
     Deduplicator to deduplicate samples at document-level using MinHashLSH.
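The same guarded-import pattern recurs in every OP diff that follows: define `OP_NAME`, wrap the heavy third-party imports in `AvailabilityChecking`, and register the class via `OP_NAME`. A minimal sketch of the idiom, with a hypothetical OP name and library (not from this patch):

```python
# Sketch of the guarded-import idiom this patch applies to each OP.
# 'my_heavy_op' and `heavy_lib` are hypothetical placeholders.
from data_juicer.utils.availability_utils import AvailabilityChecking

OP_NAME = 'my_heavy_op'

# If `heavy_lib` is not installed, the ModuleNotFoundError is suppressed
# and the OP is recorded in UNAVAILABLE_OPERATORS instead of breaking the
# whole `data_juicer.ops` import.
with AvailabilityChecking(['heavy_lib'], OP_NAME):
    import heavy_lib  # noqa: F401
```

`load_ops` (changed later in this patch) then warns about and skips any OP recorded this way.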
diff --git a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
index 958b7a50e..4e9ad1790 100644
--- a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py
@@ -7,15 +7,20 @@
 
 import numpy as np
 import regex
-import simhash
 from jsonargparse.typing import PositiveInt
 from loguru import logger
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import HashKeys
 
 from ..base_op import OPERATORS, Deduplicator
 from ..common.helper_func import split_on_whitespace
 
+OP_NAME = 'document_simhash_deduplicator'
+
+with AvailabilityChecking(['simhash-py'], OP_NAME):
+    import simhash
+
 
 def local_num_differing_bits(hash_a, hash_b):
     """
@@ -57,10 +62,7 @@ def num_differing_bits_selector():
         return simhash.num_differing_bits
 
 
-num_differing_bits = num_differing_bits_selector()
-
-
-@OPERATORS.register_module('document_simhash_deduplicator')
+@OPERATORS.register_module(OP_NAME)
 class DocumentSimhashDeduplicator(Deduplicator):
     """Deduplicator to deduplicate samples at document-level using SimHash."""
 
@@ -112,6 +114,8 @@ def __init__(self,
         self.num_blocks = num_blocks
         self.hamming_distance = hamming_distance
 
+        self.num_differing_bits = num_differing_bits_selector()
+
     def compute_hash(self, sample):
         """
         Compute simhash values for the sample.
@@ -185,7 +189,7 @@ def process(self, dataset, show_num=0):
         dist = Counter()
         for x, y in matches:
             graph[x][y] = graph[y][x] = True
-            num_diff = num_differing_bits(x, y)
+            num_diff = self.num_differing_bits(x, y)
             dist[num_diff] += 1
         logger.info(f'Hash diff distribution: {dist}')

diff --git a/data_juicer/ops/deduplicator/image_deduplicator.py b/data_juicer/ops/deduplicator/image_deduplicator.py
index 7a5275b8a..0553d5c8c 100644
--- a/data_juicer/ops/deduplicator/image_deduplicator.py
+++ b/data_juicer/ops/deduplicator/image_deduplicator.py
@@ -2,24 +2,29 @@
 from typing import Dict, Set
 
 import numpy as np
-from imagededup.methods import AHash, DHash, PHash, WHash
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, HashKeys
 from data_juicer.utils.mm_utils import load_image
 
 from ..base_op import OPERATORS, Deduplicator
 from ..op_fusion import LOADED_IMAGES
 
-HASH_METHOD = {
-    'phash': PHash(),
-    'dhash': DHash(),
-    'whash': WHash(),
-    'ahash': AHash()
-}
+OP_NAME = 'image_deduplicator'
 
+with AvailabilityChecking(['imagededup'], OP_NAME):
+    from imagededup.methods import AHash, DHash, PHash, WHash
 
-@OPERATORS.register_module('image_deduplicator')
-@LOADED_IMAGES.register_module('image_deduplicator')
+    HASH_METHOD = {
+        'phash': PHash,
+        'dhash': DHash,
+        'whash': WHash,
+        'ahash': AHash
+    }
+
+
+@OPERATORS.register_module(OP_NAME)
+@LOADED_IMAGES.register_module(OP_NAME)
 class ImageDeduplicator(Deduplicator):
     """
     Deduplicator to deduplicate samples at document-level using exact matching
@@ -38,7 +43,7 @@ def __init__(self, method: str = 'phash', *args, **kwargs):
         if method not in HASH_METHOD.keys():
             raise ValueError(f'Keep strategy [{method}] is not supported. '
                              f'Can only be one of {HASH_METHOD.keys()}.')
-        self.hasher = HASH_METHOD[method]
+        self.hasher = HASH_METHOD[method]()
 
     def compute_hash(self, sample, context=False):
         # check if it's computed already

diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py
index 7dabf5552..68ce2560b 100644
--- a/data_juicer/ops/filter/alphanumeric_filter.py
+++ b/data_juicer/ops/filter/alphanumeric_filter.py
@@ -2,12 +2,18 @@
 
 from jsonargparse.typing import PositiveFloat
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Filter
 from ..common import get_words_from_document
 
+OP_NAME = 'alphanumeric_filter'
+
+with AvailabilityChecking(['transformers'], OP_NAME):
+    import transformers  # noqa: F401
+
 
 @OPERATORS.register_module('alphanumeric_filter')
 class AlphanumericFilter(Filter):

diff --git a/data_juicer/ops/filter/clip_similarity_filter.py b/data_juicer/ops/filter/clip_similarity_filter.py
index 8724beec1..e878999fd 100644
--- a/data_juicer/ops/filter/clip_similarity_filter.py
+++ b/data_juicer/ops/filter/clip_similarity_filter.py
@@ -1,7 +1,7 @@
 import numpy as np
-import torch
 from jsonargparse.typing import ClosedUnitInterval
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.mm_utils import SpecialTokens, load_image
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -9,12 +9,18 @@
 from ..base_op import OPERATORS, Filter
 from ..op_fusion import LOADED_IMAGES
 
-# avoid hanging when calling clip in multiprocessing
-torch.set_num_threads(1)
+OP_NAME = 'clip_similarity_filter'
 
+with AvailabilityChecking(['torch'], OP_NAME):
+    import torch
+    import transformers  # noqa: F401
 
-@OPERATORS.register_module('clip_similarity_filter')
-@LOADED_IMAGES.register_module('clip_similarity_filter')
+    # avoid hanging when calling clip in multiprocessing
+    torch.set_num_threads(1)
+
+
+@OPERATORS.register_module(OP_NAME)
+@LOADED_IMAGES.register_module(OP_NAME)
 class ClipSimilarityFilter(Filter):
     """Filter to keep samples those similarity between image and text
     within a specific range."""

diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py
index fbc5e4eb8..03c603f1e 100644
--- a/data_juicer/ops/filter/flagged_words_filter.py
+++ b/data_juicer/ops/filter/flagged_words_filter.py
@@ -4,6 +4,7 @@
 
 from jsonargparse.typing import ClosedUnitInterval, List
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, InterVars, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
@@ -13,9 +14,14 @@
                       words_refinement)
 from ..op_fusion import INTER_WORDS
 
+OP_NAME = 'flagged_words_filter'
 
-@OPERATORS.register_module('flagged_words_filter')
-@INTER_WORDS.register_module('flagged_words_filter')
+with AvailabilityChecking(['sentencepiece'], OP_NAME):
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_WORDS.register_module(OP_NAME)
 class FlaggedWordFilter(Filter):
     """Filter to keep samples with flagged-word ratio less than a specific
     max value."""

diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py
index 800f635cf..5c10a0887 100644
--- a/data_juicer/ops/filter/language_id_score_filter.py
+++ b/data_juicer/ops/filter/language_id_score_filter.py
@@ -1,13 +1,19 @@
 from jsonargparse.typing import ClosedUnitInterval
 from loguru import logger
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Filter
 
+OP_NAME = 'language_id_score_filter'
 
-@OPERATORS.register_module('language_id_score_filter')
+with AvailabilityChecking(['fasttext-wheel'], OP_NAME):
+    import fasttext  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
 class LanguageIDScoreFilter(Filter):
     """Filter to keep samples in a specific language with confidence score
     larger than a specific min value."""

diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py
index 975279b7c..d125d548c 100644
--- a/data_juicer/ops/filter/perplexity_filter.py
+++ b/data_juicer/ops/filter/perplexity_filter.py
@@ -4,6 +4,7 @@
 
 from jsonargparse.typing import PositiveFloat
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, InterVars, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
@@ -11,9 +12,15 @@
 from ..common import get_words_from_document
 from ..op_fusion import INTER_WORDS
 
+OP_NAME = 'perplexity_filter'
 
-@OPERATORS.register_module('perplexity_filter')
-@INTER_WORDS.register_module('perplexity_filter')
+with AvailabilityChecking(['sentencepiece', 'kenlm'], OP_NAME):
+    import kenlm  # noqa: F401
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_WORDS.register_module(OP_NAME)
 class PerplexityFilter(Filter):
     """Filter to keep samples with perplexity score less than a specific
     max value."""

diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py
index 03c2a5a15..3d73f752f 100644
--- a/data_juicer/ops/filter/stopwords_filter.py
+++ b/data_juicer/ops/filter/stopwords_filter.py
@@ -5,6 +5,7 @@
 from jsonargparse.typing import ClosedUnitInterval, List
 
 from data_juicer.utils.asset_utils import ASSET_DIR, load_words_asset
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, InterVars, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
@@ -13,9 +14,14 @@
                       words_refinement)
 from ..op_fusion import INTER_WORDS
 
+OP_NAME = 'stopwords_filter'
 
-@OPERATORS.register_module('stopwords_filter')
-@INTER_WORDS.register_module('stopwords_filter')
+with AvailabilityChecking(['sentencepiece'], OP_NAME):
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_WORDS.register_module(OP_NAME)
 class StopWordsFilter(Filter):
     """Filter to keep samples with stopword ratio larger than a specific
     min value."""

diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py
index 5bdc58a3d..11ebd0ce5 100644
--- a/data_juicer/ops/filter/token_num_filter.py
+++ b/data_juicer/ops/filter/token_num_filter.py
@@ -2,14 +2,20 @@
 
 from jsonargparse.typing import PositiveInt
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Filter
 from ..common import get_words_from_document
 
+OP_NAME = 'token_num_filter'
 
-@OPERATORS.register_module('token_num_filter')
+with AvailabilityChecking(['transformers'], OP_NAME):
+    import transformers  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
 class TokenNumFilter(Filter):
     """Filter to keep samples with total token number within a specific
     range."""

diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/word_num_filter.py
index 98e544b1c..cc740c9d0 100644
--- a/data_juicer/ops/filter/word_num_filter.py
+++ b/data_juicer/ops/filter/word_num_filter.py
@@ -2,6 +2,7 @@
 
 from jsonargparse.typing import PositiveInt
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, InterVars, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
@@ -10,9 +11,14 @@
                       words_refinement)
 from ..op_fusion import INTER_WORDS
 
+OP_NAME = 'words_num_filter'
 
-@OPERATORS.register_module('words_num_filter')
-@INTER_WORDS.register_module('words_num_filter')
+with AvailabilityChecking(['sentencepiece'], OP_NAME):
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_WORDS.register_module(OP_NAME)
 class WordNumFilter(Filter):
     """Filter to keep samples with total words number within a specific
     range."""

diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py
index 3883541e8..126895a1a 100644
--- a/data_juicer/ops/filter/word_repetition_filter.py
+++ b/data_juicer/ops/filter/word_repetition_filter.py
@@ -4,6 +4,7 @@
 
 from jsonargparse.typing import ClosedUnitInterval, PositiveInt
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, InterVars, StatsKeys
 from data_juicer.utils.model_utils import get_model, prepare_model
 
@@ -12,9 +13,14 @@
                       words_refinement)
 from ..op_fusion import INTER_WORDS
 
+OP_NAME = 'word_repetition_filter'
 
-@OPERATORS.register_module('word_repetition_filter')
-@INTER_WORDS.register_module('word_repetition_filter')
+with AvailabilityChecking(['sentencepiece'], OP_NAME):
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_WORDS.register_module(OP_NAME)
 class WordRepetitionFilter(Filter):
     """Filter to keep samples with word-level n-gram repetition ratio within
     a specific range."""

diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py
index e8d1ed65e..9cee5a120 100644
--- a/data_juicer/ops/load.py
+++ b/data_juicer/ops/load.py
@@ -1,3 +1,7 @@
+from loguru import logger
+
+from data_juicer.utils.availability_utils import UNAVAILABLE_OPERATORS
+
 from .base_op import OPERATORS
 from .op_fusion import fuse_operators
 
@@ -15,6 +19,9 @@ def load_ops(process_list, op_fusion=False):
     ops = []
     for process in process_list:
         op_name, args = list(process.items())[0]
+        if op_name in UNAVAILABLE_OPERATORS:
+            logger.warning(UNAVAILABLE_OPERATORS[op_name].get_warning_msg())
+            continue
         ops.append(OPERATORS.modules[op_name](**args))
 
     # detect filter groups

diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py
index 8fc0a41c3..fc321094c 100644
--- a/data_juicer/ops/mapper/chinese_convert_mapper.py
+++ b/data_juicer/ops/mapper/chinese_convert_mapper.py
@@ -1,13 +1,19 @@
+from data_juicer.utils.availability_utils import AvailabilityChecking
+
 from ..base_op import OPERATORS, Mapper
 
+OP_NAME = 'chinese_convert_mapper'
+
+with AvailabilityChecking(['opencc'], OP_NAME):
+    import opencc  # noqa: F401
+
 
 def prepare_converter(mode):
     global OPENCC_CONVERTER
-    import opencc
     OPENCC_CONVERTER = opencc.OpenCC(mode + '.json')
 
 
-@OPERATORS.register_module('chinese_convert_mapper')
+@OPERATORS.register_module(OP_NAME)
 class ChineseConvertMapper(Mapper):
     """Mapper to convert Chinese between Traditional Chinese, Simplified
     Chinese and Japanese Kanji."""

diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py
index 22e092851..dc45754fa 100644
--- a/data_juicer/ops/mapper/clean_html_mapper.py
+++ b/data_juicer/ops/mapper/clean_html_mapper.py
@@ -2,12 +2,17 @@
 # https://github.com/togethercomputer/RedPajama-Data/
 # --------------------------------------------------------
 
-from selectolax.parser import HTMLParser
+from data_juicer.utils.availability_utils import AvailabilityChecking
 
 from ..base_op import OPERATORS, Mapper
 
+OP_NAME = 'clean_html_mapper'
 
-@OPERATORS.register_module('clean_html_mapper')
+with AvailabilityChecking(['selectolax'], OP_NAME):
+    from selectolax.parser import HTMLParser
+
+
+@OPERATORS.register_module(OP_NAME)
 class CleanHtmlMapper(Mapper):
     """Mapper to clean html code in text samples."""

diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py
index 275fbba28..41a686ce7 100644
--- a/data_juicer/ops/mapper/fix_unicode_mapper.py
+++ b/data_juicer/ops/mapper/fix_unicode_mapper.py
@@ -1,9 +1,14 @@
-import ftfy
+from data_juicer.utils.availability_utils import AvailabilityChecking
 
 from ..base_op import OPERATORS, Mapper
 
+OP_NAME = 'fix_unicode_mapper'
 
-@OPERATORS.register_module('fix_unicode_mapper')
+with AvailabilityChecking(['ftfy'], OP_NAME):
+    import ftfy
+
+
+@OPERATORS.register_module(OP_NAME)
 class FixUnicodeMapper(Mapper):
     """Mapper to fix unicode errors in text samples."""

diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py
index ae40b461c..11bc8fa64 100644
--- a/data_juicer/ops/mapper/nlpaug_en_mapper.py
+++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py
@@ -1,15 +1,21 @@
 from copy import deepcopy
 
-import nlpaug.augmenter.char as nac
-import nlpaug.augmenter.word as naw
-import nlpaug.flow as naf
 from loguru import logger
-from nlpaug.util import Action
+
+from data_juicer.utils.availability_utils import AvailabilityChecking
 
 from ..base_op import OPERATORS, Mapper
 
+OP_NAME = 'nlpaug_en_mapper'
+
+with AvailabilityChecking(['nlpaug'], OP_NAME):
+    import nlpaug.augmenter.char as nac
+    import nlpaug.augmenter.word as naw
+    import nlpaug.flow as naf
+    from nlpaug.util import Action
+
 
-@OPERATORS.register_module('nlpaug_en_mapper')
+@OPERATORS.register_module(OP_NAME)
 class NlpaugEnMapper(Mapper):
     """Mapper to simply augment samples in English based on nlpaug library."""

diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py
index 3f10b2f58..6125842d3 100644
--- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py
+++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py
@@ -2,12 +2,18 @@
 
 from loguru import logger
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.logger_utils import HiddenPrints
 
 from ..base_op import OPERATORS, Mapper
 
+OP_NAME = 'nlpcda_zh_mapper'
 
-@OPERATORS.register_module('nlpcda_zh_mapper')
+with AvailabilityChecking(['nlpcda'], OP_NAME), HiddenPrints():
+    import nlpcda
+
+
+@OPERATORS.register_module(OP_NAME)
 class NlpcdaZhMapper(Mapper):
     """Mapper to simply augment samples in Chinese based on nlpcda library."""
 
@@ -71,7 +77,6 @@ def __init__(self,
             import warnings
             warnings.filterwarnings('ignore')
-            import nlpcda
             self.aug_pipeline = []
 
             # sample level

diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
index c6f7c5e43..4835d4de5 100644
--- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
+++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
@@ -1,5 +1,6 @@
 from jsonargparse.typing import List
 
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Mapper
@@ -7,8 +8,13 @@
                              merge_on_whitespace_tab_newline,
                              split_on_newline_tab_whitespace, strip)
 
+OP_NAME = 'remove_words_with_incorrect_substrings_mapper'
 
-@OPERATORS.register_module('remove_words_with_incorrect_substrings_mapper')
+with AvailabilityChecking(['sentencepiece'], OP_NAME):
+    import sentencepiece  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
 class RemoveWordsWithIncorrectSubstringsMapper(Mapper):
     """Mapper to remove words with incorrect substrings."""

diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py
index 65a02308e..12e1372c8 100644
--- a/data_juicer/ops/mapper/sentence_split_mapper.py
+++ b/data_juicer/ops/mapper/sentence_split_mapper.py
@@ -1,10 +1,16 @@
+from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.model_utils import get_model, prepare_model
 
 from ..base_op import OPERATORS, Mapper
 from ..common import get_sentences_from_document
 
+OP_NAME = 'sentence_split_mapper'
 
-@OPERATORS.register_module('sentence_split_mapper')
+with AvailabilityChecking(['nltk'], OP_NAME):
+    import nltk  # noqa: F401
+
+
+@OPERATORS.register_module(OP_NAME)
 class SentenceSplitMapper(Mapper):
     """Mapper to split text samples to sentences."""

diff --git a/data_juicer/utils/availability_utils.py b/data_juicer/utils/availability_utils.py
new file mode 100644
index 000000000..bc3ca475d
--- /dev/null
+++ b/data_juicer/utils/availability_utils.py
@@ -0,0 +1,79 @@
+from loguru import logger
+
+UNAVAILABLE_OPERATORS = {}
+
+
+class UnavailableOperator:
+
+    def __init__(self, op_name, requires):
+        self.op_name = op_name
+        self.requires = requires
+
+    def get_warning_msg(self):
+        return f'This OP [{self.op_name}] is unavailable because importing ' \
+               f'its third-party requirements failed: ' \
+               f'{self.requires}. You can either run ' \
+               f'`pip install -v -e .[sci]` to install all requirements for ' \
+               f'all OPs, or run `pip install {" ".join(self.requires)}` ' \
+               f'with the library versions specified in ' \
+               f'`environments/science_requires.txt` to install only the ' \
+               f'libraries required by this OP. Data processing will skip ' \
+               f'this OP.'
+
+
+class AvailabilityChecking:
+    """Context manager that checks the availability of third-party libraries
+    for OPs or other situations. If the check fails, the corresponding OP is
+    added to the unavailable-OP list and skipped with a warning when OPs are
+    initialized.
+    """
+
+    def __init__(
+        self,
+        requires_list,
+        op_name=None,
+        requires_type=None,
+    ):
+        """
+        Initialization method.
+
+        :param requires_list: libraries to import within this context
+        :param op_name: which OP requires these libraries. By default, it's
+            None, which means the import does not happen inside an OP.
+        :param requires_type: the dependency category these libraries belong
+            to (e.g. 'dist'), used to suggest the matching extras group in
+            the error message. It's None by default.
+        """
+        self.requires_list = requires_list
+        self.op_name = op_name
+        self.requires_type = requires_type
+
+        self.error_msg = f'No module named {self.requires_list}. You might ' \
+                         f'need to install them by running `pip install ' \
+                         f'{" ".join(self.requires_list)}`.'
+        if self.requires_type:
+            self.error_msg += f' Or install all related requirements by ' \
+                              f'running `pip install -v -e ' \
+                              f'.[{self.requires_type}]`'
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is ModuleNotFoundError:
+            if self.op_name:
+                # ModuleNotFoundError for an OP: register it to
+                # UNAVAILABLE_OPERATORS
+                UNAVAILABLE_OPERATORS[self.op_name] = UnavailableOperator(
+                    op_name=self.op_name,
+                    requires=self.requires_list,
+                )
+            else:
+                # other situations: print the error message and exit
+                logger.error(f'{exc_type.__name__}: {exc_val}')
+                logger.error(f'{exc_tb.tb_frame}')
+                logger.error(self.error_msg)
+                exit(0)
+        elif exc_type is None:
+            # libraries were imported successfully
+            pass
+        else:
+            # other exceptions: re-raise them directly
+            return False
+
+        # return True to suppress the handled exception
+        return True

diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt
new file mode 100644
index 000000000..e02756318
--- /dev/null
+++ b/environments/dist_requires.txt
@@ -0,0 +1 @@
+ray

diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt
index 3fda59110..1202407a8 100644
--- a/environments/minimal_requires.txt
+++ b/environments/minimal_requires.txt
@@ -1,11 +1,14 @@
 fsspec==2023.3.0
-pyarrow<=13.0.0
+pyarrow<=12.0.0
 pandas==2.0.0
 datasets==2.11.0
 loguru
+tabulate
 tqdm
 jsonargparse[signatures]
 matplotlib
+emoji==2.2.0
+regex
 requests
 wget
 zstandard
@@ -17,4 +20,3 @@ streamlit
 spacy==3.5.0
 multiprocess==0.70.12
 dill==0.3.4
-ray

diff --git a/environments/science_requires.txt b/environments/science_requires.txt
index bf03d7661..f13d5b740 100644
--- a/environments/science_requires.txt
+++ b/environments/science_requires.txt
@@ -2,11 +2,7 @@ fasttext-wheel
 kenlm
 sentencepiece
 scipy
-tabulate
-pandas
 ftfy
-emoji==2.2.0
-regex
 simhash-py
 selectolax
 nlpaug

diff --git a/setup.py b/setup.py
index 741b2e29a..9ce0369b9 100644
--- a/setup.py
+++ b/setup.py
@@ -28,11 +28,14 @@ def get_install_requirements(require_f_paths, env_dir='environments'):
 # allowing selective installment based on users' needs
 # TODO: The specific taxonomy and dependencies will be determined
 # after implementing some preliminary operators and detailed discussions
-min_requires = get_install_requirements(
-    ['minimal_requires.txt', 'science_requires.txt'])
+min_requires = get_install_requirements(['minimal_requires.txt'])
 extra_requires = {
     'mini': min_requires,
+    'sci':
+    get_install_requirements(['science_requires.txt']),
+    'dist':
+    get_install_requirements(['dist_requires.txt']),
     'dev':
     get_install_requirements(['dev_requires.txt']),
     'tools':