Enhance/support prompt for caption generation (#191)
* support prompt for generate_caption_mapper
* update docs
* add a warning when both prompt and prompt_key are set
HYLcool authored Jan 19, 2024
1 parent 2d4cee5 commit ab4d3c8
Showing 2 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions configs/config_all.yaml
@@ -55,6 +55,8 @@ process:
caption_num: 1 # how many candidate captions to generate for each image
keep_candidate_mode: 'random_any' # retention strategy for the generated $caption_num$ candidates. Should be in ["random_any", "similar_one_simhash", "all"].
keep_original_sample: true # whether to keep the original sample. If it's set to False, only the generated captions will be kept in the final dataset and the original captions will be removed. It's True by default.
prompt: null # a string prompt used to guide the generation of the blip2 model for all samples globally. It's None by default, which means no prompt is provided.
prompt_key: null # the key name of the field in samples that stores the prompt for each sample. It's used to set a different prompt for each sample. If it's None, the prompt given in the parameter "prompt" is used instead. It's None by default.
- image_blur_mapper: # mapper to blur images.
p: 0.2 # probability of the image being blurred
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
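For reference, the snippet below is a minimal Python sketch of how these two new config keys map onto the operator's constructor. It assumes the class defined in generate_caption_mapper.py is named GenerateCaptionMapper and that the BLIP2 model argument keeps its default; the prompt string is purely illustrative.

from data_juicer.ops.mapper.generate_caption_mapper import GenerateCaptionMapper

# construct the op with a single global prompt; prompt_key stays None, so the
# same prompt is applied to every sample (illustrative values, not from the diff)
op = GenerateCaptionMapper(
    caption_num=1,
    keep_candidate_mode='random_any',
    keep_original_sample=True,
    prompt='Question: what is shown in the image? Answer:',
    prompt_key=None,
)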
30 changes: 30 additions & 0 deletions data_juicer/ops/mapper/generate_caption_mapper.py
@@ -3,6 +3,7 @@

import numpy as np
from jsonargparse.typing import PositiveInt
from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
@@ -38,6 +39,8 @@ def __init__(self,
caption_num: PositiveInt = 1,
keep_candidate_mode: str = 'random_any',
keep_original_sample: bool = True,
prompt: str = None,
prompt_key: str = None,
*args,
**kwargs):
"""
@@ -64,6 +67,13 @@ def __init__(self,
it's set to False, only the generated captions will be kept in the
final dataset and the original captions will be removed. It's True
by default.
:param prompt: a string prompt used to guide the generation of the
blip2 model for all samples globally. It's None by default, which
means no prompt is provided.
:param prompt_key: the key name of the field in samples that stores
the prompt for each sample. It's used to set a different prompt for
each sample. If it's None, the prompt given in the parameter
"prompt" is used instead. It's None by default.
:param args: extra args
:param kwargs: extra args
"""
@@ -87,6 +97,8 @@ def __init__(self,
self.caption_num = caption_num
self.keep_candidate_mode = keep_candidate_mode
self.keep_original_sample = keep_original_sample
self.prompt = prompt
self.prompt_key = prompt_key
self.extra_args = kwargs

if keep_candidate_mode in ['random_any', 'similar_one_simhash']:
@@ -96,6 +108,12 @@ def __init__(self,
else:
self.num_newly_generated_samples = 0

# report a warning when both prompt and prompt_key are set
if self.prompt and self.prompt_key:
logger.warning(
'Both the parameter `prompt` and `prompt_key` are '
'set. Data-Juicer will consider `prompt_key` first.')

def _process_single_sample(self, ori_sample):
"""
@@ -153,7 +171,19 @@ def _process_single_sample(self, ori_sample):
# generated_text_candidates_single_chunk[i][j] indicates
# the $i$-th generated candidate for the $j$-th image

# construct prompts
if self.prompt_key \
and isinstance(ori_sample[self.prompt_key], str):
# check prompt_key is not None, and it's a str in the sample
prompt_texts = [ori_sample[self.prompt_key]] * len(image_chunk)
elif self.prompt and isinstance(self.prompt, str):
# check prompt is not None, and it's a str
prompt_texts = [self.prompt] * len(image_chunk)
else:
prompt_texts = None

inputs = self.img_processor_in_ctx(images=image_chunk,
text=prompt_texts,
return_tensors='pt')
for i in range(self.caption_num):
generated_ids = self.model_in_ctx.generate(**inputs,
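The prompt-selection precedence added above can be summarized by the self-contained sketch below. The helper name, the sample layout, and the field name 'qa_prompt' are illustrative assumptions, not part of the diff.

def select_prompt_texts(sample, num_images, prompt=None, prompt_key=None):
    """Return one prompt per image in the chunk, or None if no prompt is set."""
    if prompt_key and isinstance(sample.get(prompt_key), str):
        # a per-sample prompt stored under prompt_key wins over the global one
        return [sample[prompt_key]] * num_images
    if prompt and isinstance(prompt, str):
        # otherwise fall back to the single global prompt
        return [prompt] * num_images
    return None

sample = {
    'text': 'a photo of a busy street',
    'images': ['street.jpg'],
    'qa_prompt': 'Question: how many cars are visible? Answer:',
}
print(select_prompt_texts(sample, num_images=1, prompt_key='qa_prompt'))
# ['Question: how many cars are visible? Answer:']
print(select_prompt_texts(sample, num_images=1, prompt='Describe the image.'))
# ['Describe the image.']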
