From 2689413fb0aff01a833257140460a35e2eeedb96 Mon Sep 17 00:00:00 2001 From: BeachWang <1400012807@pku.edu.cn> Date: Tue, 27 Aug 2024 11:48:28 +0800 Subject: [PATCH 1/4] update spacy to deal conflict with ms-swift (#397) * update_spacy * fix model version * keep model 3.5.0 * update spacy to 3.7.0 & support native tar.gz package * update docker version * update librosa version * update nltk version --------- Co-authored-by: gece.gc --- .github/workflows/docker/docker-compose.yml | 4 +-- data_juicer/utils/model_utils.py | 37 +++++++++++++++++---- environments/minimal_requires.txt | 4 +-- environments/science_requires.txt | 4 +-- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yml b/.github/workflows/docker/docker-compose.yml index eeba32206..92a5c76c2 100644 --- a/.github/workflows/docker/docker-compose.yml +++ b/.github/workflows/docker/docker-compose.yml @@ -1,7 +1,7 @@ version: '3' services: ray-head: - image: data-juicer-unittest:0.2.1 + image: data-juicer-unittest:0.2.2 pull_policy: never command: ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block environment: @@ -30,7 +30,7 @@ services: capabilities: [gpu] ray-worker: - image: data-juicer-unittest:0.2.1 + image: data-juicer-unittest:0.2.2 pull_policy: never command: ray start --address=ray-head:6379 --block environment: diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index fe716333d..f145e4a76 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -72,7 +72,7 @@ def check_model(model_name, force=False): ) else: logger.info( - f'Model [{cached_model_path}] not found . Downloading...') + f'Model [{cached_model_path}] not found. Downloading...') try: model_link = os.path.join(MODEL_LINKS, model_name) @@ -406,7 +406,7 @@ def prepare_huggingface_model(pretrained_model_name_or_path, return (model, processor) if return_model else processor -def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'): +def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.7.0'): """ Prepare spacy model for specific language. 
@@ -419,17 +419,40 @@ def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.5.0'): assert lang in ['zh', 'en'], 'Diversity only support zh and en' model_name = name_pattern.format(lang) logger.info(f'Loading spacy model [{model_name}]...') - compressed_model = '{}.zip'.format(model_name) + compressed_model = '{}.tar.gz'.format(model_name) # decompress the compressed model if it's not decompressed def decompress_model(compressed_model_path): - decompressed_model_path = compressed_model_path.replace('.zip', '') + if not compressed_model_path.endswith('.tar.gz'): + raise ValueError('Only .tar.gz files are supported') + + decompressed_model_path = compressed_model_path.replace('.tar.gz', '') if os.path.exists(decompressed_model_path) \ and os.path.isdir(decompressed_model_path): return decompressed_model_path - import zipfile - with zipfile.ZipFile(compressed_model_path) as zf: - zf.extractall(DJMC) + + ver_name = os.path.basename(decompressed_model_path) + unver_name = ver_name.rsplit('-', maxsplit=1)[0] + target_dir_in_archive = f'{ver_name}/{unver_name}/{ver_name}/' + + import tarfile + with tarfile.open(compressed_model_path, 'r:gz') as tar: + for member in tar.getmembers(): + if member.name.startswith(target_dir_in_archive): + # relative path without unnecessary directory levels + relative_path = os.path.relpath( + member.name, start=target_dir_in_archive) + target_path = os.path.join(decompressed_model_path, + relative_path) + + if member.isfile(): + # ensure the directory exists + target_directory = os.path.dirname(target_path) + os.makedirs(target_directory, exist_ok=True) + # for files, extract to the specific location + with tar.extractfile(member) as source: + with open(target_path, 'wb') as target: + target.write(source.read()) return decompressed_model_path try: diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt index c162fb21d..bd55d2008 100644 --- a/environments/minimal_requires.txt +++ b/environments/minimal_requires.txt @@ -4,7 +4,7 @@ pandas==2.0.3 datasets==2.18.0 av soundfile -librosa +librosa>=0.10 loguru tabulate tqdm @@ -21,7 +21,7 @@ pdfplumber plotly python-docx streamlit -spacy==3.5.0 +spacy==3.7.0 multiprocess==0.70.12 dill==0.3.4 psutil diff --git a/environments/science_requires.txt b/environments/science_requires.txt index e1ab796cd..c1350368b 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -10,7 +10,7 @@ simhash-pybind selectolax nlpaug nlpcda -nltk +nltk<3.9 transformers>=4.37 transformers_stream_generator einops @@ -18,7 +18,7 @@ accelerate tiktoken opencc==1.1.6 imagededup -spacy-pkuseg==0.0.32 +spacy-pkuseg diffusers simple-aesthetics-predictor scenedetect[opencv] From 22834baea64a4def7348ecc18b1cde4e68ec738d Mon Sep 17 00:00:00 2001 From: Cathy0908 <30484308+Cathy0908@users.noreply.github.com> Date: Thu, 29 Aug 2024 20:28:48 +0800 Subject: [PATCH 2/4] support LLM augmentation ops and support vllm (#338) * support GenerateInstructionMapper * fix EmptyFormatter * fix extract qa * add op optimize_instruction_mapper and support vllm * fix model infer * update param tensor_parallel_size * add param sampling_params for extract QA op * add param sampling_params * add arg for vllm * add _accelerator * add num_proc=1 for data aug op * optimize codes * update mapper init * fix unittest bug * fix unittest --- .gitignore | 1 + configs/config_all.yaml | 30 +- data_juicer/config/config.py | 9 +- data_juicer/core/executor.py | 1 + data_juicer/core/ray_executor.py | 12 +- 
data_juicer/format/__init__.py | 8 +- data_juicer/format/empty_formatter.py | 84 ++++++ data_juicer/format/load.py | 14 + data_juicer/ops/mapper/__init__.py | 9 +- data_juicer/ops/mapper/extract_qa_mapper.py | 85 +++++- .../ops/mapper/generate_instruction_mapper.py | 284 ++++++++++++++++++ .../ops/mapper/optimize_instruction_mapper.py | 123 ++++++++ data_juicer/utils/model_utils.py | 40 +++ demos/data/demo-dataset-chatml.jsonl | 4 + docs/Operators.md | 5 +- docs/Operators_ZH.md | 5 +- environments/science_requires.txt | 1 + tests/format/test_empty_formatter.py | 44 +++ tests/ops/mapper/test_extract_qa_mapper.py | 20 +- .../test_generate_instruction_mapper.py | 42 +++ .../test_optimize_instruction_mapper.py | 36 +++ 21 files changed, 833 insertions(+), 24 deletions(-) create mode 100644 data_juicer/format/empty_formatter.py create mode 100644 data_juicer/ops/mapper/generate_instruction_mapper.py create mode 100644 data_juicer/ops/mapper/optimize_instruction_mapper.py create mode 100644 demos/data/demo-dataset-chatml.jsonl create mode 100644 tests/format/test_empty_formatter.py create mode 100644 tests/ops/mapper/test_generate_instruction_mapper.py create mode 100644 tests/ops/mapper/test_optimize_instruction_mapper.py diff --git a/.gitignore b/.gitignore index 15c65f412..6ea36c585 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ dist .idea/ wandb/ __pycache__ +.vscode/ diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 855d45731..9cb64fa30 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -62,8 +62,29 @@ process: - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. - extract_qa_mapper: # mapper to extract question and answer pair from text. - hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' + hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # model name on huggingface to extract question and answer pair. + pattern: null # regular expression pattern to search for within text. + qa_format: 'chatml' # Output format of question and answer pair. + enable_vllm: true # Whether to use vllm for inference acceleration. + tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism. + max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config. + max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - fix_unicode_mapper: # fix unicode errors in text. + - generate_instruction_mapper: # generate new instruction text data. + hf_model: 'Qwen/Qwen-7B-Chat' # model name on huggingface to generate instruction. + seed_file: 'demos/data/demo-dataset-chatml.jsonl' # Seed file as instruction samples to generate new instructions, chatml format. + instruct_num: 3 # the number of generated samples. + similarity_threshold: 0.7 # the similarity score threshold between the generated samples and the seed samples.Range from 0 to 1. Samples with similarity score less than this threshold will be kept. + prompt_template: null # Prompt template for generate samples. Please make sure the template contains "{augmented_data}", which corresponds to the augmented samples. + qa_pair_template: null # Prompt template for generate question and answer pair description. 
Please make sure the template contains two "{}" to format question and answer. Default: '【问题】\n{}\n【回答】\n{}\n'. + example_template: null # Prompt template for generate examples. Please make sure the template contains "{qa_pairs}", which corresponds to the question and answer pair description generated by param `qa_pair_template`. + qa_extraction_pattern: null # Regular expression pattern for parsing question and answer from model response. + enable_vllm: true # Whether to use vllm for inference acceleration. + tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism. + max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config. + max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] @@ -123,6 +144,13 @@ process: delete_random_char: false # whether to open the augmentation method of deleting random characters from the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据增强" swap_random_char: false # whether to open the augmentation method of swapping random contiguous characters in the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据强增方法" replace_equivalent_num: false # whether to open the augmentation method of replacing random numbers with their equivalent representations in the original texts. **Notice**: Only for numbers for now. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有伍种不同的数据增强方法" + - optimize_instruction_mapper: # optimize instruction. + hf_model: 'alibaba-pai/Qwen2-7B-Instruct-Refine' # model name on huggingface to optimize instruction + enable_vllm: true # whether to use vllm for inference acceleration. + tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism. + max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config. + max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. - remove_bibliography_mapper: # remove bibliography from Latex text. - remove_comments_mapper: # remove comments from Latex text, code, etc. diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index c8f8f9ded..b6f6efaa8 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -86,6 +86,13 @@ def init_configs(args=None): help='Path to datasets with optional weights(0.0-1.0), 1.0 as ' 'default. Accepted format: dataset1-path dataset2-path ' ' dataset3-path ...') + parser.add_argument( + '--generated_dataset_config', + type=Dict, + default=None, + help='Configuration used to create a dataset. ' + 'The dataset will be created from this configuration if provided. 
' + 'It must contain the `type` field to specify the dataset name.') parser.add_argument( '--export_path', type=str, @@ -371,7 +378,7 @@ def init_setup_from_cfg(cfg): redirect=cfg.executor_type == 'default') # check and get dataset dir - if os.path.exists(cfg.dataset_path): + if cfg.get('dataset_path', None) and os.path.exists(cfg.dataset_path): cfg.dataset_path = os.path.abspath(cfg.dataset_path) if os.path.isdir(cfg.dataset_path): cfg.dataset_dir = cfg.dataset_path diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 5949df76d..87e38dbce 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -49,6 +49,7 @@ def __init__(self, cfg=None): # setup formatter logger.info('Setting up data formatter...') self.formatter = load_formatter(self.cfg.dataset_path, + self.cfg.generated_dataset_config, self.cfg.text_keys, self.cfg.suffixes, self.cfg.add_suffix) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index ae1a51359..a071c2dea 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -47,7 +47,17 @@ def run(self, load_data_np=None): """ # 1. load data logger.info('Loading dataset with Ray...') - dataset = rd.read_json(self.cfg.dataset_path) + + if self.cfg.get('generated_dataset_config', None): + generated_dataset_config = self.cfg.generated_dataset_config + assert isinstance(generated_dataset_config, + dict) and 'type' in generated_dataset_config + args = generated_dataset_config.copy() + obj_name = args.pop('type') + from data_juicer.format.formatter import FORMATTERS + dataset = FORMATTERS.modules[obj_name](**args).load_dataset() + else: + dataset = rd.read_json(self.cfg.dataset_path) # convert all the path in dataset to absolute path dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) diff --git a/data_juicer/format/__init__.py b/data_juicer/format/__init__.py index e25ec9921..a0473fca3 100644 --- a/data_juicer/format/__init__.py +++ b/data_juicer/format/__init__.py @@ -1,6 +1,8 @@ -from . import (csv_formatter, json_formatter, mixture_formatter, - parquet_formatter, text_formatter, tsv_formatter) +from . import (csv_formatter, empty_formatter, json_formatter, + mixture_formatter, parquet_formatter, text_formatter, + tsv_formatter) from .csv_formatter import CsvFormatter +from .empty_formatter import EmptyFormatter, RayEmptyFormatter from .formatter import LocalFormatter, RemoteFormatter from .json_formatter import JsonFormatter from .load import load_formatter @@ -12,5 +14,5 @@ __all__ = [ 'load_formatter', 'JsonFormatter', 'LocalFormatter', 'RemoteFormatter', 'TextFormatter', 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', - 'MixtureFormatter' + 'MixtureFormatter', 'EmptyFormatter', 'RayEmptyFormatter' ] diff --git a/data_juicer/format/empty_formatter.py b/data_juicer/format/empty_formatter.py new file mode 100644 index 000000000..527cd9851 --- /dev/null +++ b/data_juicer/format/empty_formatter.py @@ -0,0 +1,84 @@ +from typing import List + +import pandas as pd +import ray +from datasets import Dataset, Features, Value + +from .formatter import FORMATTERS, BaseFormatter + + +@FORMATTERS.register_module() +class EmptyFormatter(BaseFormatter): + """ + The class is used to create empty data. + """ + SUFFIXES = [] + + def __init__(self, length, feature_keys: List[str] = [], *args, **kwargs): + """ + Initialization method. + + :param length: The empty dataset length. + :param feature_keys: feature key name list. 
+ """ + self.length = length + self.feature_keys = feature_keys + if isinstance(self.feature_keys, str): + self.feature_keys = [self.feature_keys] + + @property + def null_value(self): + return None + + def load_dataset(self, *args, **kwargs): + data_dict = {} + features = Features() + + for key in self.feature_keys: + features.update({key: Value('string')}) + data_dict.update( + {key: [self.null_value for _ in range(self.length)]}) + + empty_dataset = Dataset.from_dict(data_dict, features=features) + + from data_juicer.core.data import NestedDataset + empty_dataset = NestedDataset(empty_dataset) + + return empty_dataset + + +@FORMATTERS.register_module() +class RayEmptyFormatter(BaseFormatter): + """ + The class is used to create empty data for ray. + """ + SUFFIXES = [] + + def __init__(self, length, feature_keys: List[str] = [], *args, **kwargs): + """ + Initialization method. + + :param length: The empty dataset length. + :param feature_keys: feature key name list. + """ + self.length = length + self.feature_keys = feature_keys + if isinstance(self.feature_keys, str): + self.feature_keys = [self.feature_keys] + + @property + def null_value(self): + return {} + + def load_dataset(self, *args, **kwargs): + if len(self.feature_keys): + df = pd.DataFrame({ + col: [self.null_value for _ in range(self.length)] + for col in self.feature_keys + }) + else: + df = pd.DataFrame([self.null_value for _ in range(self.length)]) + + empty_dataset = ray.data.from_pandas(df) + + return empty_dataset diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py index e2bc148c4..3a65817be 100644 --- a/data_juicer/format/load.py +++ b/data_juicer/format/load.py @@ -3,6 +3,7 @@ def load_formatter(dataset_path, + generated_dataset_config=None, text_keys=None, suffixes=[], add_suffix=False, @@ -12,6 +13,9 @@ def load_formatter(dataset_path, weight(default 1.0) according to their formats. :param dataset_path: path to a dataset file or a dataset directory + :param generated_dataset_config: Configuration used to create a dataset. + The dataset will be created from this configuration if provided. + It must contain the `type` field to specify the dataset name. :param text_keys: key names of field that stores sample text. Default: None :param suffixes: files with specified suffixes to be processed. @@ -19,6 +23,16 @@ def load_formatter(dataset_path, info :return: a dataset formatter. """ + if generated_dataset_config: + assert isinstance(generated_dataset_config, + dict) and 'type' in generated_dataset_config + args = generated_dataset_config.copy() + obj_name = args.pop('type') + args.update(kwargs) + + from .formatter import FORMATTERS + return FORMATTERS.modules[obj_name](**args) + formatter = MixtureFormatter(dataset_path=dataset_path, text_keys=text_keys, suffixes=suffixes, diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5213498e9..d0e32825c 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -2,10 +2,11 @@ from . 
import (audio_ffmpeg_wrapped_mapper, chinese_convert_mapper, clean_copyright_mapper, clean_email_mapper, clean_html_mapper, clean_ip_mapper, clean_links_mapper, expand_macro_mapper, - extract_qa_mapper, fix_unicode_mapper, image_blur_mapper, + extract_qa_mapper, fix_unicode_mapper, + generate_instruction_mapper, image_blur_mapper, image_captioning_from_gpt4v_mapper, image_captioning_mapper, image_diffusion_mapper, image_face_blur_mapper, - nlpaug_en_mapper, nlpcda_zh_mapper, + nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper, punctuation_normalization_mapper, remove_bibliography_mapper, remove_comments_mapper, remove_header_mapper, remove_long_words_mapper, remove_non_chinese_character_mapper, @@ -34,6 +35,7 @@ from .expand_macro_mapper import ExpandMacroMapper from .extract_qa_mapper import ExtractQAMapper from .fix_unicode_mapper import FixUnicodeMapper +from .generate_instruction_mapper import GenerateInstructionMapper from .image_blur_mapper import ImageBlurMapper from .image_captioning_from_gpt4v_mapper import ImageCaptioningFromGPT4VMapper from .image_captioning_mapper import ImageCaptioningMapper @@ -41,6 +43,7 @@ from .image_face_blur_mapper import ImageFaceBlurMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper +from .optimize_instruction_mapper import OptimizeInstructionMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -92,6 +95,7 @@ 'VideoFFmpegWrappedMapper', 'ChineseConvertMapper', 'NlpcdaZhMapper', + 'OptimizeInstructionMapper', 'ImageBlurMapper', 'CleanCopyrightMapper', 'RemoveNonChineseCharacterlMapper', @@ -108,6 +112,7 @@ 'RemoveWordsWithIncorrectSubstringsMapper', 'VideoCaptioningFromVideoMapper', 'VideoCaptioningFromSummarizerMapper', + 'GenerateInstructionMapper', 'FixUnicodeMapper', 'NlpaugEnMapper', 'VideoCaptioningFromFramesMapper', diff --git a/data_juicer/ops/mapper/extract_qa_mapper.py b/data_juicer/ops/mapper/extract_qa_mapper.py index 52d117cea..31767543f 100644 --- a/data_juicer/ops/mapper/extract_qa_mapper.py +++ b/data_juicer/ops/mapper/extract_qa_mapper.py @@ -1,12 +1,27 @@ import json -import logging import re +from typing import Dict -from data_juicer.ops.base_op import OPERATORS, Mapper +from loguru import logger + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.model_utils import get_model, prepare_model +OP_NAME = 'extract_qa_mapper' + +with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME): + import torch + import transformers # noqa: F401 + import vllm # noqa: F401 + + # avoid hanging when calling model in multiprocessing + torch.set_num_threads(1) + -@OPERATORS.register_module('extract_qa_mapper') +# TODO: Extend LLM-based OPs into API-based implementation. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class ExtractQAMapper(Mapper): """ Mapper to extract question and answer pair from text samples. 
@@ -23,20 +38,36 @@ class ExtractQAMapper(Mapper): """ _accelerator = 'cuda' - _batched_op = True def __init__(self, hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa', trust_remote_code=False, pattern: str = None, qa_format: str = 'chatml', + enable_vllm: bool = True, + tensor_parallel_size: int = None, + max_model_len: int = None, + max_num_seqs: int = 256, + sampling_params: Dict = {}, *args, **kwargs): """ Initialization method. :param hf_model: Hugginface model id. + :param trust_remote_code: passed to transformers :param pattern: regular expression pattern to search for within text. :param qa_format: Output format of question and answer pair. + :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. + The number of GPUs to use for distributed execution with tensor + parallelism. + :param max_model_len: It is only valid when enable_vllm is True. + Model context length. If unspecified, will be automatically + derived from the model config. + :param max_num_seqs: It is only valid when enable_vllm is True. + Maximum number of sequences to be processed in a single iteration. + :param sampling_params: Sampling parameters for text generation. + e.g {'temperature': 0.9, 'top_p': 0.95} :param args: extra args :param kwargs: extra args @@ -55,6 +86,7 @@ def __init__(self, """ super().__init__(*args, **kwargs) + self.num_proc = 1 if pattern is None: self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)' @@ -62,9 +94,31 @@ def __init__(self, self.pattern = pattern self.qa_format = qa_format - self.model_key = prepare_model(model_type='huggingface', - pretrained_model_name_or_path=hf_model, - trust_remote_code=trust_remote_code) + self.enable_vllm = enable_vllm + + if enable_vllm: + import torch + from vllm import SamplingParams + + assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' + if not tensor_parallel_size: + tensor_parallel_size = torch.cuda.device_count() + logger.info(f'Set tensor_parallel_size to \ + {tensor_parallel_size} for vllm.') + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs) + self.sampling_params = SamplingParams(**sampling_params) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code) + self.sampling_params = sampling_params def _extract_qa(self, output): """Extract qestion and answer pair from model output response.""" @@ -82,14 +136,21 @@ def _extract_qa(self, output): def process(self, sample, rank=None): model, processor = get_model(self.model_key, rank, self.use_cuda()) - inputs = processor(sample[self.text_key], - return_tensors='pt').to(model.device) - response = model.generate(**inputs) - output = processor.decode(response.cpu()[0], skip_special_tokens=True) + if self.enable_vllm: + response = model.generate([sample[self.text_key]], + self.sampling_params) + output = response[0].outputs[0].text + else: + inputs = processor(sample[self.text_key], + return_tensors='pt').to(model.device) + response = model.generate(**inputs, **self.sampling_params) + output = processor.decode(response.cpu()[0], + skip_special_tokens=True) + qa_list = self._extract_qa(output) if not len(qa_list): - logging.info( + logger.info( 'No question and answer data was extracted from this sample!') 
dialogue_data = [] diff --git a/data_juicer/ops/mapper/generate_instruction_mapper.py b/data_juicer/ops/mapper/generate_instruction_mapper.py new file mode 100644 index 000000000..92269554d --- /dev/null +++ b/data_juicer/ops/mapper/generate_instruction_mapper.py @@ -0,0 +1,284 @@ +import json +import random +import re +from typing import Dict + +from loguru import logger + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, UNFORKABLE, Mapper + +DEFAULT_PROMPT_TEMPLATE = """ +请你仔细观察多个示例数据的输入和输出,按照你的理解,总结出相应规矩,然后写出一个新的【问题】和【回答】。注意,新生成的【问题】和【回答】需要满足如下要求: +1. 生成的【问题】和【回答】不能与输入的【问题】和【回答】一致,但是需要保持格式相同。 +2. 生成的【问题】不一定要局限于输入【问题】的话题或领域,生成的【回答】需要正确回答生成的【问题】。 +3. 提供的【问题】和【回答】可能是多轮对话,生成的【问题】和【回答】也可以是多轮,但是需要保持格式相同。 +4. 生成的【问题】和【回答】必须成对出现,而且【问题】需要在【回答】之前。 +{augmented_data} +""" +QA_EXTRACTION_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*?)\s*(?=【问题】|$)' +EXAMPLE_TEMPLATE = '\n如下是一条示例数据:\n\n{qa_pairs}' +QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}\n' + +OP_NAME = 'generate_instruction_mapper' + +with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME): + import torch + import transformers # noqa: F401 + import vllm # noqa: F401 + + # avoid hanging when calling model in multiprocessing + torch.set_num_threads(1) + + +# TODO: Extend LLM-based OPs into API-based implementation. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class GenerateInstructionMapper(Mapper): + """Mapper to generate new instruction text data. + You should configure an empty dataset in your yaml config file: + ``` + generated_dataset_config: + type: 'EmptyFormatter' # use `RayEmptyFormatter` when enable ray + length: ${The number of generated samples} + feature_keys: ${text key} + ``` + The number of samples generated is determined by + the length of the empty dataset. + """ + _accelerator = 'cuda' + + def __init__(self, + hf_model, + seed_file, + instruct_num, + trust_remote_code: bool = False, + similarity_threshold: float = 0.7, + prompt_template: str = None, + qa_pair_template: str = None, + example_template: str = None, + qa_extraction_pattern: str = None, + enable_vllm: bool = True, + tensor_parallel_size: int = None, + max_model_len: int = None, + max_num_seqs: int = 256, + sampling_params: Dict = {}, + *args, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model id. + :param seed_file: Seed file path, chatml format. + :param instruct_num: The number of instruction samples. + Randomly select N samples from "seed_file" and + put them into prompt as instruction samples. + :param trust_remote_code: passed to transformers + :param similarity_threshold: The similarity score threshold + between the generated samples and the seed samples. + Range from 0 to 1. Samples with similarity score less than + this threshold will be kept. + :param prompt_template: Prompt template for generate samples. + Please make sure the template contains "{augmented_data}", + which corresponds to the augmented samples. + :param qa_pair_template: Prompt template for generate question + and answer pair description. Please make sure the template + contains two "{}" to format question and answer. + Default: '【问题】\n{}\n【回答】\n{}\n'. + :param example_template: Prompt template for generate examples. + Please make sure the template contains "{qa_pairs}", which + corresponds to the question and answer pair description + generated by param `qa_pair_template`. 
+ Default: '\n如下是一条示例数据:\n\n{qa_pairs}' + :param qa_extraction_pattern: Regular expression pattern for parsing + question and answer from model response. + :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. + The number of GPUs to use for distributed execution with tensor + parallelism. + :param max_model_len: It is only valid when enable_vllm is True. + Model context length. If unspecified, will be automatically + derived from the model config. + :param max_num_seqs: It is only valid when enable_vllm is True. + Maximum number of sequences to be processed in a single iteration. + :param sampling_params: Sampling parameters for text generation. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.num_proc = 1 + + self.instruct_num = instruct_num + self.similarity_threshold = similarity_threshold + self.similarity_type = 'rouge_l' + + if prompt_template is None: + prompt_template = DEFAULT_PROMPT_TEMPLATE + if qa_pair_template is None: + qa_pair_template = QA_PAIR_TEMPLATE + if example_template is None: + example_template = EXAMPLE_TEMPLATE + if qa_extraction_pattern is None: + qa_extraction_pattern = QA_EXTRACTION_PATTERN + + self.prompt_template = prompt_template + self.qa_pair_template = qa_pair_template + self.example_template = example_template + self.qa_extraction_pattern = qa_extraction_pattern + + self.enable_vllm = enable_vllm + + if enable_vllm: + import torch + from vllm import SamplingParams + + assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' + if not tensor_parallel_size: + tensor_parallel_size = torch.cuda.device_count() + logger.info(f'Set tensor_parallel_size to \ + {tensor_parallel_size} for vllm.') + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs) + self.sampling_params = SamplingParams(**sampling_params) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code) + self.sampling_params = sampling_params + + self.seed_qa_samples = self.load_seed_qa_samples(seed_file) + + if len(self.seed_qa_samples) == 0: + raise ValueError('No QA data was parsed from the seed file!') + + self.reference_samples = [ + '\n'.join(['\n'.join(qa_pair) for qa_pair in qa_pairs]) + '\n' + for qa_pairs in self.seed_qa_samples + ] + + def load_seed_qa_samples(self, seed_file): + """Load QA pairs from chatml format file.""" + qa_samples = [] + with open(seed_file) as f: + lines = f.readlines() + for line in lines: + line = line.strip() + qa_pairs = self.parse_chatml_str(line) + if len(qa_pairs) > 0: + qa_samples.append(qa_pairs) + + return qa_samples + + def build_prompt(self, qa_samples, prompt_template): + + def format_qa_pairs(qa_pairs): + return ''.join([ + self.qa_pair_template.format(q, a) for q, a in qa_pairs + if q and a + ]) + + body_fragments = [ + self.example_template.format(qa_pairs=format_qa_pairs(qa_pairs)) + for qa_pairs in qa_samples + ] + + body = ''.join(body_fragments) + + return prompt_template.format(augmented_data=body) + + def parse_chatml_str(self, input_str): + user_input = None + assistant_output = None + qa_pairs = [] + data = json.loads(input_str) + for message in data['messages']: + role = 
message['role'] + content = message['content'] + if role == 'user': + user_input = content + elif role == 'assistant': + assistant_output = content + qa_pairs.append((user_input, assistant_output)) + return qa_pairs + + def parse_response(self, response_str): + pattern = self.qa_extraction_pattern + matches = re.findall(pattern, response_str, re.DOTALL) + response_str = '' + out_qa_pairs = [] + for i, match in enumerate(matches): + question, answer = match + question = question.strip() + answer = answer.strip() + out_qa_pairs.append((question, answer)) + response_str += question + '\n' + answer + '\n' + + if len(out_qa_pairs) == 0: + logger.error('Parse model response error! ' + 'No data generated for the current response!') + + return out_qa_pairs, response_str + + def max_rouge_l_score(self, reference, candidates): + from rouge import Rouge + + rouge = Rouge() + max_score = 0.0 + for candidate in candidates: + scores = rouge.get_scores(candidate, reference) + rouge_l_score = scores[0]['rouge-l']['f'] + if rouge_l_score > max_score: + max_score = rouge_l_score + return max_score + + def process(self, sample=None, rank=None): + model, processor = get_model(self.model_key, rank=rank) + + random_qa_samples = random.sample(self.seed_qa_samples, + self.instruct_num) + input_prompt = self.build_prompt(random_qa_samples, + self.prompt_template) + if self.enable_vllm: + response = model.generate([input_prompt], self.sampling_params) + response_str = response[0].outputs[0].text + else: + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + output_ids = model.generate(**inputs, **self.sampling_params) + # remove the input prompt from the output + output_ids = output_ids[:, inputs.data['input_ids'].shape[1]:] + response_str = processor.decode(output_ids.cpu()[0], + skip_special_tokens=True) + message_list = [] + out_qa_pairs, response_str = self.parse_response(response_str) + + if not response_str: + return {self.text_key: json.dumps({'messages': message_list})} + + if self.similarity_type == 'rouge_l': + sim_score = self.max_rouge_l_score(response_str, + self.reference_samples) + else: + raise ValueError( + f'Not support similarity type "{self.similarity_type}"!') + + if sim_score <= self.similarity_threshold: + for question, answer in out_qa_pairs: + message_list.append({'role': 'user', 'content': question}) + message_list.append({'role': 'assistant', 'content': answer}) + else: + logger.info('Filter this generated sample due to similarity.') + + return { + self.text_key: + json.dumps({'messages': message_list}, ensure_ascii=False) + } diff --git a/data_juicer/ops/mapper/optimize_instruction_mapper.py b/data_juicer/ops/mapper/optimize_instruction_mapper.py new file mode 100644 index 000000000..32785dc27 --- /dev/null +++ b/data_juicer/ops/mapper/optimize_instruction_mapper.py @@ -0,0 +1,123 @@ +from typing import Dict + +from loguru import logger + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.model_utils import get_model, prepare_model + +DEFAULT_SYSTEM_PROMPT = '请优化这个指令,将其修改为一个更详细具体的指令。' + +OP_NAME = 'optimize_instruction_mapper' + +with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME): + import torch + import transformers # noqa: F401 + import vllm # noqa: F401 + + # avoid hanging when calling model in multiprocessing + torch.set_num_threads(1) + + +# TODO: Extend LLM-based OPs into API-based implementation. 
+@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class OptimizeInstructionMapper(Mapper): + """Mapper to optimize instruction. + Recommended model list: [ + alibaba-pai/Qwen2-1.5B-Instruct-Refine + alibaba-pai/Qwen2-7B-Instruct-Refine + ] + """ + _accelerator = 'cuda' + + def __init__(self, + hf_model: str = 'alibaba-pai/Qwen2-7B-Instruct-Refine', + trust_remote_code: bool = False, + system_prompt: str = None, + enable_vllm: bool = True, + tensor_parallel_size: int = None, + max_model_len: int = None, + max_num_seqs: int = 256, + sampling_params: Dict = {}, + *args, + **kwargs): + """ + Initialization method. + :param hf_model: Hugginface model id. + :param trust_remote_code: passed to transformers + :param system_prompt: System prompt for optimize samples. + :param enable_vllm: Whether to use vllm for inference acceleration. + :param tensor_parallel_size: It is only valid when enable_vllm is True. + The number of GPUs to use for distributed execution with tensor + parallelism. + :param max_model_len: It is only valid when enable_vllm is True. + Model context length. If unspecified, will be automatically + derived from the model config. + :param max_num_seqs: It is only valid when enable_vllm is True. + Maximum number of sequences to be processed in a single iteration. + :param sampling_params: Sampling parameters for text generation. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.num_proc = 1 + + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + self.system_prompt = system_prompt + self.enable_vllm = enable_vllm + + if enable_vllm: + import torch + from vllm import SamplingParams + + assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' + if not tensor_parallel_size: + tensor_parallel_size = torch.cuda.device_count() + logger.info(f'Set tensor_parallel_size to \ + {tensor_parallel_size} for vllm.') + self.model_key = prepare_model( + model_type='vllm', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs) + self.sampling_params = SamplingParams(**sampling_params) + else: + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=hf_model, + trust_remote_code=trust_remote_code) + self.sampling_params = sampling_params + + def process(self, sample=None, rank=None): + model, processor = get_model(self.model_key, rank=rank) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': sample[self.text_key] + }] + input_prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + + if self.enable_vllm: + response = model.generate([input_prompt], self.sampling_params) + output = response[0].outputs[0].text + else: + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + response = model.generate(**inputs, + eos_token_id=processor.eos_token_id, + **self.sampling_params) + output = processor.decode(response.cpu()[0], + skip_special_tokens=True) + + sample[self.text_key] = output + + return sample diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index f145e4a76..a36d960f5 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -406,6 +406,45 @@ def prepare_huggingface_model(pretrained_model_name_or_path, return 
(model, processor) if return_model else processor +def prepare_vllm_model(pretrained_model_name_or_path, + return_model=True, + trust_remote_code=False, + tensor_parallel_size=1, + max_model_len=None, + max_num_seqs=256): + """ + Prepare and load a HuggingFace model with the correspoding processor. + + :param pretrained_model_name_or_path: model name or path + :param return_model: return model or not + :param trust_remote_code: passed to transformers + :param tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + :param max_model_len: Model context length. If unspecified, will + be automatically derived from the model config. + :param max_num_seqs: Maximum number of sequences to be processed in a + single iteration. + :return: a tuple (model, input processor) if `return_model` is True; + otherwise, only the processor is returned. + """ + from transformers import AutoProcessor + from vllm import LLM as vLLM + + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + + if return_model: + import torch + model = vLLM(model=pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + dtype=torch.float16, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs) + + return (model, processor) if return_model else processor + + def prepare_spacy_model(lang, name_pattern='{}_core_web_md-3.7.0'): """ Prepare spacy model for specific language. @@ -570,6 +609,7 @@ def prepare_opencv_classifier(model_path): 'diffusion': prepare_diffusion_model, 'video_blip': prepare_video_blip_model, 'recognizeAnything': prepare_recognizeAnything_model, + 'vllm': prepare_vllm_model, 'opencv_classifier': prepare_opencv_classifier, } diff --git a/demos/data/demo-dataset-chatml.jsonl b/demos/data/demo-dataset-chatml.jsonl new file mode 100644 index 000000000..46b837934 --- /dev/null +++ b/demos/data/demo-dataset-chatml.jsonl @@ -0,0 +1,4 @@ +{"messages": [{"role": "system", "content": "You are a helpful assistant"}, {"role": "user", "content": "谁在文艺复兴时期绘制人体?"}, {"role": "assistant", "content": "文艺复兴时期是一个关于艺术、文化和学术的复兴运动,在这个时期,许多艺术家都绘制了人体。"},{"role": "user", "content": "那雕塑方面如何呢?"}, {"role": "assistant", "content": "文艺复兴时期的雕塑也非常有名,几位世界级的雕塑大师都出自于这个时期。"}]} +{"messages":[{"content":"You are a helpful assistant","role":"system"},{"content":"什么时期的音乐家开始广泛使用交响乐团?","role":"user"},{"content":"浪漫主义时期,音乐家们开始广泛使用和扩展交响乐团,创作出规模宏大、情感丰富的交响乐作品。","role":"assistant"}]} +{"messages":[{"content":"You are a helpful assistant","role":"system"},{"content":"哪个物理定律描述了物体在不受外力作用时保持静止或匀速直线运动的状态?","role":"user"},{"content":"牛顿第一定律,也称为惯性定律,描述了物体在不受外力作用时保持静止状态或匀速直线运动的状态。","role":"assistant"}]} +{"messages":[{"content":"You are a helpful assistant","role":"system"},{"content":"哪种文学流派强调通过象征和暗喻探索潜意识思维?","role":"user"},{"content":"现代主义文学流派强调通过象征、暗喻以及非线性叙述等手法,深入探索人物的内心世界与潜意识思维。","role":"assistant"}]} \ No newline at end of file diff --git a/docs/Operators.md b/docs/Operators.md index a35210161..144550790 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples | | [ Filter ]( #filter ) | 41 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -58,7 +58,9 @@ All the specific operators are listed below, each featured with several capabili | clean_ip_mapper | General | en, zh | Removes IP addresses | | clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | | expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | +| extract_qa_mapper | General | en, zh | Extract question and answer pair from text samples. | | fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | +| generate_instruction_mapper | General | en, zh | Generate instruction text samples.| | image_blur_mapper | Image | - | Blur images | | image_captioning_from_gpt4v_mapper | Multimodal | - | generate samples whose texts are generated based on gpt-4-visison and the image | | image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | @@ -66,6 +68,7 @@ All the specific operators are listed below, each featured with several capabili | image_face_blur_mapper | Image | - | Blur faces detected in images | | nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | | nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | +| optimize_instruction_mapper | General | en, zh | Optimize instruction text samples.| | punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | | remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | | remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 855d109a7..3d0e33df3 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 41 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -57,7 +57,9 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_ip_mapper | General | en, zh | 删除 IP 地址 | | clean_links_mapper | General, Code | en, zh | 删除链接,例如以 http 或 ftp 开头的 | | expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | +| extract_qa_mapper | General | en, zh | 从文本中抽取问答对 | | fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | +| generate_instruction_mapper | General | en, zh | 指令扩充,根据种子数据,生成新的样本。 | | image_blur_mapper | Image | - | 对图像进行模糊处理 | | image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | | image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 
blip2)和原始样本中的图形生成的。 | @@ -65,6 +67,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | | nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | | nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | +| optimize_instruction_mapper | General | en, zh | 指令优化,优化prompt。| | punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | | remove_bibliography_mapper | LaTeX | en, zh | 删除 TeX 文档的参考文献 | | remove_comments_mapper | LaTeX | en, zh | 删除 TeX 文档中的注释 | diff --git a/environments/science_requires.txt b/environments/science_requires.txt index c1350368b..1a63e64e8 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -24,3 +24,4 @@ simple-aesthetics-predictor scenedetect[opencv] ffmpeg-python opencv-python +vllm diff --git a/tests/format/test_empty_formatter.py b/tests/format/test_empty_formatter.py new file mode 100644 index 000000000..d9400777c --- /dev/null +++ b/tests/format/test_empty_formatter.py @@ -0,0 +1,44 @@ +import os +import unittest + +from data_juicer.format.empty_formatter import EmptyFormatter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class EmptyFormatterTest(DataJuicerTestCaseBase): + + text_key = 'text' + + def test_empty_dataset(self): + ds_len = 10 + formatter = EmptyFormatter(length=ds_len, feature_keys=[self.text_key]) + ds = formatter.load_dataset() + + self.assertEqual(len(ds), ds_len) + self.assertEqual(list(ds.features.keys()), [self.text_key]) + + for item in ds: + self.assertDictEqual(item, {self.text_key: None}) + + # test map + update_column = {self.text_key: 1} + + def map_fn(sample): + sample.update(update_column) + return sample + + ds = ds.map(map_fn) + self.assertEqual(len(ds), ds_len) + for item in ds: + self.assertDictEqual(item, update_column) + + # test filter + def filter_fn(sample): + return sample[self.text_key] > 2 + + ds = ds.filter(filter_fn) + self.assertEqual(len(ds), 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_qa_mapper.py b/tests/ops/mapper/test_extract_qa_mapper.py index 6d659b61f..648996a9f 100644 --- a/tests/ops/mapper/test_extract_qa_mapper.py +++ b/tests/ops/mapper/test_extract_qa_mapper.py @@ -10,14 +10,18 @@ class ExtractQAMapperTest(DataJuicerTestCaseBase): text_key = 'text' - def _run_extract_qa(self, samples): + def _run_extract_qa(self, samples, enable_vllm=False, sampling_params={}, **kwargs): op = ExtractQAMapper( hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa', - qa_format='chatml' + qa_format='chatml', + enable_vllm=enable_vllm, + sampling_params=sampling_params, + **kwargs ) for sample in samples: result = op.process(sample) out_text = json.loads(result[self.text_key]) + print(f'Output sample: {out_text}') # test one output qa sample qa_sample = out_text[0] @@ -31,6 +35,18 @@ def test_extract_qa(self): }] self._run_extract_qa(samples) + def test_extract_qa_vllm(self): + samples = [ + { + self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n' + }] + self._run_extract_qa( + samples, + enable_vllm=True, + max_model_len=1024, + max_num_seqs=16, + sampling_params={'temperature': 0.9, 'top_p': 0.95, 'max_tokens': 256}) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_generate_instruction_mapper.py b/tests/ops/mapper/test_generate_instruction_mapper.py new file mode 100644 index 000000000..0bd7a1099 --- /dev/null +++ b/tests/ops/mapper/test_generate_instruction_mapper.py @@ -0,0 +1,42 @@ +import unittest 
+import json +from data_juicer.ops.mapper.generate_instruction_mapper import GenerateInstructionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class GenerateInstructionMapperTest(DataJuicerTestCaseBase): + + text_key = 'text' + + def _run_generate_instruction(self, enable_vllm=False): + op = GenerateInstructionMapper( + hf_model='Qwen/Qwen-7B-Chat', + seed_file='demos/data/demo-dataset-chatml.jsonl', + instruct_num=2, + enable_vllm=enable_vllm + ) + + from data_juicer.format.empty_formatter import EmptyFormatter + dataset = EmptyFormatter(3, [self.text_key]).load_dataset() + + dataset = dataset.map(op.process) + + for item in dataset: + out_sample = json.loads(item[self.text_key]) + print(f'Output sample: {out_sample}') + # test one output qa sample + self.assertIn('role', out_sample['messages'][0]) + self.assertIn('content', out_sample['messages'][0]) + + def test_generate_instruction(self): + self._run_generate_instruction() + + def test_generate_instruction_vllm(self): + self._run_generate_instruction(enable_vllm=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_optimize_instruction_mapper.py b/tests/ops/mapper/test_optimize_instruction_mapper.py new file mode 100644 index 000000000..7c7b58b4c --- /dev/null +++ b/tests/ops/mapper/test_optimize_instruction_mapper.py @@ -0,0 +1,36 @@ +import unittest +from data_juicer.ops.mapper.optimize_instruction_mapper import OptimizeInstructionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. 
+@SKIPPED_TESTS.register_module() +class OptimizeInstructionMapperTest(DataJuicerTestCaseBase): + + text_key = 'text' + + def _run_optimize_instruction(self, enable_vllm=False): + op = OptimizeInstructionMapper( + hf_model='alibaba-pai/Qwen2-7B-Instruct-Refine', + enable_vllm=enable_vllm + ) + + samples = [ + {self.text_key: '鱼香肉丝怎么做?'} + ] + + for sample in samples: + result = op.process(sample) + print(f'Output results: {result}') + self.assertIn(self.text_key, result) + + def test_optimize_instruction(self): + self._run_optimize_instruction() + + def test_optimize_instruction_vllm(self): + self._run_optimize_instruction(enable_vllm=True) + + +if __name__ == '__main__': + unittest.main() From 60a15404d14904e87b315f108e7b4a05fdcc3ca1 Mon Sep 17 00:00:00 2001 From: garyzhang99 <46197280+garyzhang99@users.noreply.github.com> Date: Sun, 1 Sep 2024 12:26:24 +0800 Subject: [PATCH 3/4] rename typo test (#407) --- .../{test_exapnd_macro_mapper.py => test_expand_macro_mapper.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/ops/mapper/{test_exapnd_macro_mapper.py => test_expand_macro_mapper.py} (100%) diff --git a/tests/ops/mapper/test_exapnd_macro_mapper.py b/tests/ops/mapper/test_expand_macro_mapper.py similarity index 100% rename from tests/ops/mapper/test_exapnd_macro_mapper.py rename to tests/ops/mapper/test_expand_macro_mapper.py From 8ea47b7d5818c577981f763710f6255f7281948b Mon Sep 17 00:00:00 2001 From: garyzhang99 <46197280+garyzhang99@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:54:25 +0800 Subject: [PATCH 4/4] use analyzer instead of analyser to maintain consistency (#410) --- configs/demo/{analyser.yaml => analyzer.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename configs/demo/{analyser.yaml => analyzer.yaml} (100%) diff --git a/configs/demo/analyser.yaml b/configs/demo/analyzer.yaml similarity index 100% rename from configs/demo/analyser.yaml rename to configs/demo/analyzer.yaml
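
For reference, the pieces added in PATCH 2/4 can be combined into a single recipe: an empty dataset created through `generated_dataset_config` (the new `EmptyFormatter`) determines how many samples `generate_instruction_mapper` produces. The sketch below is illustrative only — the file name, `project_name`, `export_path`, sample count, and sampling parameters are assumptions, not values taken from the patches.

```yaml
# hypothetical recipe, e.g. configs/demo/generate_instruction.yaml
project_name: 'demo-instruction-generation'            # assumed project name
export_path: './outputs/generated-instructions.jsonl'  # assumed output path

# no dataset_path: the dataset is created from this config instead
generated_dataset_config:
  type: 'EmptyFormatter'        # use 'RayEmptyFormatter' when running on ray
  length: 100                   # number of samples to generate (assumed)
  feature_keys: ['text']        # must match the text key used by the op

process:
  - generate_instruction_mapper:
      hf_model: 'Qwen/Qwen-7B-Chat'
      seed_file: 'demos/data/demo-dataset-chatml.jsonl'
      instruct_num: 3
      similarity_threshold: 0.7
      enable_vllm: true
      sampling_params: {'temperature': 0.9, 'top_p': 0.95}  # illustrative values
```

Each generated sample is written back to the text key as a ChatML-style JSON string (`{"messages": [...]}`), so the output can be refined further with, for example, the new `optimize_instruction_mapper` in a follow-up recipe.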