Skip to content

Commit

Permalink
* init image tagging mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
HYLcool committed Sep 9, 2024
1 parent adb4ac9 commit 7e653c6
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
11 changes: 7 additions & 4 deletions data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
generate_instruction_mapper, image_blur_mapper,
image_captioning_from_gpt4v_mapper, image_captioning_mapper,
image_diffusion_mapper, image_face_blur_mapper,
nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper,
punctuation_normalization_mapper, remove_bibliography_mapper,
remove_comments_mapper, remove_header_mapper,
remove_long_words_mapper, remove_non_chinese_character_mapper,
image_tagging_mapper, nlpaug_en_mapper, nlpcda_zh_mapper,
optimize_instruction_mapper, punctuation_normalization_mapper,
remove_bibliography_mapper, remove_comments_mapper,
remove_header_mapper, remove_long_words_mapper,
remove_non_chinese_character_mapper,
remove_repeat_sentences_mapper, remove_specific_chars_mapper,
remove_table_text_mapper,
remove_words_with_incorrect_substrings_mapper,
Expand Down Expand Up @@ -41,6 +42,7 @@
from .image_captioning_mapper import ImageCaptioningMapper
from .image_diffusion_mapper import ImageDiffusionMapper
from .image_face_blur_mapper import ImageFaceBlurMapper
from .image_tagging_mapper import ImageTaggingMapper
from .nlpaug_en_mapper import NlpaugEnMapper
from .nlpcda_zh_mapper import NlpcdaZhMapper
from .optimize_instruction_mapper import OptimizeInstructionMapper
Expand Down Expand Up @@ -123,6 +125,7 @@
'AudioFFmpegWrappedMapper',
'VideoSplitByDurationMapper',
'VideoFaceBlurMapper',
'ImageTaggingMapper',
]

# yapf: enable
77 changes: 77 additions & 0 deletions data_juicer/ops/mapper/image_tagging_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from collections import Counter

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_tagging_mapper'

with AvailabilityChecking(
['torch', 'git+https://github.com/xinyu1205/recognize-anything.git'],
OP_NAME):
import ram # noqa: F401
import torch

# avoid hanging when calling recognizeAnything in multiprocessing
torch.set_num_threads(1)


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageTaggingMapper(Mapper):
"""Mapper to generate image tags.
"""

_accelerator = 'cuda'

def __init__(self, *args, **kwargs):
"""
Initialization method.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.model_key = prepare_model(
model_type='recognizeAnything',
pretrained_model_name_or_path='ram_plus_swin_large_14m.pth',
input_size=384)
from ram import get_transform
self.transform = get_transform(image_size=384)

def process(self, sample, rank=None, context=False):
# check if it's generated already
if Fields.image_tags in sample:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.image_tags] = []
return sample

# load images
loaded_image_keys = sample[self.image_key]
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

model = get_model(self.model_key, rank, self.use_cuda())
image_tags = []
for _, value in enumerate(loaded_image_keys):
image = images[value]

image_tensor = torch.unsqueeze(self.transform(image)).to(
next(model.parameters()).device)
with torch.no_grad():
tags, _ = model.generate_tag(image_tensor)

words = [word.strip() for tag in tags for word in tag.split('|')]
word_count = Counter(words)
sorted_word_list = [item for item, _ in word_count.most_common()]
image_tags.append(sorted_word_list)

sample[Fields.image_tags] = image_tags
return sample
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class Fields(object):
# video_frame_tags
video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
# image_tags
image_tags = DEFAULT_PREFIX + 'image_tags__'

# the name of the original file from which this sample was derived.
source_file = DEFAULT_PREFIX + 'source_file__'
Expand Down

0 comments on commit 7e653c6

Please sign in to comment.