diff --git a/configs/config_all.yaml b/configs/config_all.yaml index fbe4b4780..df4aad524 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -55,6 +55,8 @@ process: caption_num: 1 # how many candidate captions to generate for each image keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt is provided. + prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used to set different prompts for different samples. If it's None, use prompt in parameter "prompt". It's None in default. - image_blur_mapper: # mapper to blur images. 
p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py index 2303a2b60..056ebe20c 100644 --- a/data_juicer/ops/mapper/generate_caption_mapper.py +++ b/data_juicer/ops/mapper/generate_caption_mapper.py @@ -3,6 +3,7 @@ import numpy as np from jsonargparse.typing import PositiveInt +from loguru import logger from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -38,6 +39,8 @@ def __init__(self, caption_num: PositiveInt = 1, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, + prompt: str = None, + prompt_key: str = None, *args, **kwargs): """ @@ -64,6 +67,13 @@ def __init__(self, it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + :param prompt: a string prompt to guide the generation of blip2 model + for all samples globally. It's None in default, which means no + prompt is provided. + :param prompt_key: the key name of fields in samples to store prompts + for each sample. It's used to set different prompts for different + samples. If it's None, use prompt in parameter "prompt". It's None + in default. 
:param args: extra args :param kwargs: extra args """ @@ -87,6 +97,8 @@ def __init__(self, self.caption_num = caption_num self.keep_candidate_mode = keep_candidate_mode self.keep_original_sample = keep_original_sample + self.prompt = prompt + self.prompt_key = prompt_key self.extra_args = kwargs if keep_candidate_mode in ['random_any', 'similar_one_simhash']: @@ -96,6 +108,12 @@ def __init__(self, else: self.num_newly_generated_samples = 0 + # report a warning when both prompt and prompt_key are set + if self.prompt and self.prompt_key: + logger.warning( + 'Both the parameter `prompt` and `prompt_key` are ' + 'set. Data-Juicer will consider `prompt_key` first.') + def _process_single_sample(self, ori_sample): """ @@ -153,7 +171,19 @@ def _process_single_sample(self, ori_sample): # generated_text_candidates_single_chunk[i][j] indicates # the $i$-th generated candidate for the $j$-th image + # construct prompts + if self.prompt_key \ + and isinstance(ori_sample[self.prompt_key], str): + # check prompt_key is not None, and it's a str in the sample + prompt_texts = [ori_sample[self.prompt_key]] * len(image_chunk) + elif self.prompt and isinstance(self.prompt, str): + # check prompt is not None, and it's a str + prompt_texts = [self.prompt] * len(image_chunk) + else: + prompt_texts = None + inputs = self.img_processor_in_ctx(images=image_chunk, + text=prompt_texts, return_tensors='pt') for i in range(self.caption_num): generated_ids = self.model_in_ctx.generate(**inputs,