Feat/analyzer enhance & multimodal + text deduplication (#313)
* Support setting percentiles for the analyser
* Support deciding whether to export the dataset in the analyser via configs
* Support multimodal data deduplication together with texts
HYLcool authored May 14, 2024
1 parent 0d81116 commit 370e620
Showing 8 changed files with 236 additions and 22 deletions.
4 changes: 4 additions & 0 deletions configs/config_all.yaml
@@ -40,6 +40,8 @@ executor_type: default # type of executor,
ray_address: auto # the address of the Ray cluster.

# only for data analysis
percentiles: [0.25, 0.5, 0.75] # percentiles to analyse the dataset distribution
export_original_dataset: false # whether to export the original dataset with stats. If you only need the stats of the dataset, setting it to false could speed up the exporting.
save_stats_in_one_file: false # whether to store all stats result into one file

# for sandbox or hpo
@@ -478,7 +480,9 @@ process:
ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash.
- image_deduplicator: # deduplicator to deduplicate samples at document-level using exact matching of images between documents.
method: phash # hash method for image. One of [phash, dhash, whash, ahash]
consider_text: false # whether to consider text hash together with image hash when applying deduplication.
- video_deduplicator: # deduplicator to deduplicate samples at document-level using exact matching of videos between documents.
consider_text: false # whether to consider text hash together with video hash when applying deduplication.
- ray_video_deduplicator: # the simple video deduplicator that can run on multi-nodes using md5 hashing exact matching method
redis_host: 'redis_host' # the host of the redis instance
redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port
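For reference, a rough sketch (not part of this commit) of an analysis run that picks up the two new YAML keys above; the Analyser class name and the init_configs entry point are assumed from the files changed further down in this diff:

from data_juicer.config import init_configs
from data_juicer.core.analyser import Analyser

# parse the YAML config shown above; percentiles and export_original_dataset
# land on the cfg object alongside the existing options
cfg = init_configs(args=['--config', 'configs/config_all.yaml'])

analyser = Analyser(cfg)
analyser.run()  # with export_original_dataset: false, only the stats are exported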
13 changes: 13 additions & 0 deletions data_juicer/config/config.py
@@ -301,6 +301,19 @@ def init_configs(args=None):
type=List[Dict],
help='List of several operators with their arguments, these ops will '
'be applied to dataset in order')
parser.add_argument(
'--percentiles',
type=List[float],
default=[],
help='Percentiles to analyse the dataset distribution. Only used in '
'Analysis.')
parser.add_argument(
'--export_original_dataset',
type=bool,
default=False,
help='whether to export the original dataset with stats. If you only '
'need the stats of the dataset, setting it to false could speed '
'up the exporting.')
parser.add_argument(
'--save_stats_in_one_file',
type=bool,
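As a standalone illustration of the two new arguments (an assumption-laden sketch: it presumes the parser behind init_configs is a jsonargparse ArgumentParser, which accepts typing annotations such as List[float] and JSON-style values on the command line):

from typing import List

from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--percentiles', type=List[float], default=[])
parser.add_argument('--export_original_dataset', type=bool, default=False)

cfg = parser.parse_args(['--percentiles', '[0.25, 0.5, 0.75]',
                         '--export_original_dataset', 'false'])
print(cfg.percentiles)               # [0.25, 0.5, 0.75]
print(cfg.export_original_dataset)   # False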
20 changes: 12 additions & 8 deletions data_juicer/core/analyser.py
@@ -51,12 +51,14 @@ def __init__(self, cfg=None):
# (export_ds=False). Instead, only need to export stats
# (export_stats=True).
logger.info('Preparing exporter...')
self.exporter = Exporter(self.cfg.export_path,
self.cfg.export_shard_size,
self.cfg.export_in_parallel,
self.cfg.np,
export_ds=False,
export_stats=True)
self.exporter = Exporter(
self.cfg.export_path,
self.cfg.export_shard_size,
self.cfg.export_in_parallel,
self.cfg.np,
export_ds=self.cfg.export_original_dataset,
keep_stats_in_res_ds=self.cfg.export_original_dataset,
export_stats=True)

# parsed_res
self.overall_result = None
@@ -121,8 +123,10 @@ def run(self, load_data_np=None, skip_export=False):

logger.info('Applying overall analysis on stats...')
overall_analysis = OverallAnalysis(dataset, self.analysis_path)
self.overall_result = overall_analysis.analyse(num_proc=self.cfg.np,
skip_export=skip_export)
self.overall_result = overall_analysis.analyse(
percentiles=self.cfg.percentiles,
num_proc=self.cfg.np,
skip_export=skip_export)

logger.info(f'The overall analysis results are: {self.overall_result}')

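The diff only forwards cfg.percentiles; a hypothetical sketch of how OverallAnalysis.analyse could apply them to the stats table (the actual implementation is not shown in this commit) might rely on pandas' describe:

import pandas as pd

def overall_describe(stats_df: pd.DataFrame, percentiles=None):
    # pandas always reports the median; extra quantile rows are added on top
    return stats_df.describe(percentiles=sorted(set(percentiles or [])),
                             include='all')

stats_df = pd.DataFrame({'text_len': [3, 7, 11, 20], 'image_count': [1, 1, 2, 3]})
print(overall_describe(stats_df, percentiles=[0.25, 0.75]))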
7 changes: 5 additions & 2 deletions data_juicer/core/exporter.py
@@ -125,8 +125,11 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
dataset = dataset.remove_columns(removed_fields)
if not self.keep_hashes_in_res_ds:
extra_fields = {
HashKeys.hash, HashKeys.minhash, HashKeys.simhash,
HashKeys.imagehash
HashKeys.hash,
HashKeys.minhash,
HashKeys.simhash,
HashKeys.imagehash,
HashKeys.videohash,
}
feature_fields = set(dataset.features.keys())
removed_fields = extra_fields.intersection(feature_fields)
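In isolation, the export-time cleanup this hunk extends works roughly as below (a sketch: plain strings stand in for the HashKeys constants, and the sample data is made up):

from datasets import Dataset

extra_fields = {'hash', 'minhash', 'simhash', 'imagehash', 'videohash'}

ds = Dataset.from_list([
    {'text': 'a', 'videohash': '9f86d081...', 'hash': 123},
    {'text': 'b', 'videohash': 'ca978112...', 'hash': 456},
])

# drop whichever hash columns actually exist in the dataset before export
removed_fields = extra_fields.intersection(ds.features.keys())
ds = ds.remove_columns(list(removed_fields))
print(ds.column_names)  # ['text']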
32 changes: 27 additions & 5 deletions data_juicer/ops/deduplicator/image_deduplicator.py
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, Set
from typing import Dict, Set, Tuple

import numpy as np

@@ -9,6 +9,7 @@

from ..base_op import OPERATORS, Deduplicator
from ..op_fusion import LOADED_IMAGES
from .document_deduplicator import DocumentDeduplicator

OP_NAME = 'image_deduplicator'

@@ -38,11 +39,17 @@ class ImageDeduplicator(Deduplicator):
of images between documents.
"""

def __init__(self, method: str = 'phash', *args, **kwargs):
def __init__(self,
method: str = 'phash',
consider_text: bool = False,
*args,
**kwargs):
"""
Initialization method.
:param method: hash method for image
:param consider_text: whether to consider text hash together with image
hash when applying deduplication.
:param args: extra args
:param kwargs: extra args
"""
@@ -51,6 +58,10 @@ def __init__(self, method: str = 'phash', *args, **kwargs):
raise ValueError(f'Keep strategy [{method}] is not supported. '
f'Can only be one of {HASH_METHOD}.')
self.hasher = get_hash_method(method)()
self.consider_text = consider_text
self.text_dedup_op = None
if self.consider_text:
self.text_dedup_op = DocumentDeduplicator(**kwargs)

def compute_hash(self, sample, context=False):
# check if it's computed already
@@ -71,6 +82,8 @@ def compute_hash(self, sample, context=False):
for key in images:
sample[HashKeys.imagehash] += self.hasher.encode_image(
image_array=np.array(images[key]))
if self.consider_text:
sample = self.text_dedup_op.compute_hash(sample)
return sample

def process(self, dataset, show_num=0):
@@ -89,8 +102,14 @@ def process(self, dataset, show_num=0):
dup_hashes = None
if show_num > 0:
# sample duplicate pairs
hash2ids: Dict[int, Set[int]] = defaultdict(set)
for sid, hash_val in enumerate(dataset[HashKeys.imagehash]):
if self.consider_text:
hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
hashes = zip(dataset[HashKeys.imagehash],
dataset[HashKeys.hash])
else:
hash2ids: Dict[int, Set[int]] = defaultdict(set)
hashes = dataset[HashKeys.imagehash]
for sid, hash_val in enumerate(hashes):
if hash_val:
hash2ids[hash_val].add(sid)
dup_samples = sorted(list(hash2ids.items()),
@@ -101,7 +120,10 @@
][:show_num])

def _filter_dup_helper(sample, hashes):
hash = sample[HashKeys.imagehash]
if self.consider_text:
hash = (sample[HashKeys.imagehash], sample[HashKeys.hash])
else:
hash = sample[HashKeys.imagehash]
if not hash:
return True
if show_num > 0 and hash in dup_hashes \
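A usage sketch mirroring the new tests further down: with consider_text=True, two samples are only treated as duplicates when both their image hashes and their text hashes match. Temporary all-black images are generated here so the snippet has something to hash; the field names follow the tests:

import os
import tempfile

import numpy as np
from datasets import Dataset
from PIL import Image

from data_juicer.ops.deduplicator.image_deduplicator import ImageDeduplicator

tmp_dir = tempfile.mkdtemp()
img_a = os.path.join(tmp_dir, 'a.png')
img_b = os.path.join(tmp_dir, 'b.png')
Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(img_a)
Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(img_b)  # pixel-identical copy

ds = Dataset.from_list([
    {'images': [img_a], 'text': '<image> same caption'},
    {'images': [img_b], 'text': '<image> same caption'},     # same image + same text -> dropped
    {'images': [img_b], 'text': '<image> another caption'},  # same image, new text -> kept
])

op = ImageDeduplicator(consider_text=True)
ds = ds.map(op.compute_hash)
ds, _ = op.process(ds)
print(ds.to_list())  # the first and third samples survive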
28 changes: 23 additions & 5 deletions data_juicer/ops/deduplicator/video_deduplicator.py
@@ -1,12 +1,13 @@
import hashlib
from collections import defaultdict
from typing import Dict, Set
from typing import Dict, Set, Tuple

from data_juicer.utils.constant import HashKeys
from data_juicer.utils.mm_utils import load_data_with_context, load_video

from ..base_op import OPERATORS, Deduplicator
from ..op_fusion import LOADED_VIDEOS
from .document_deduplicator import DocumentDeduplicator

OP_NAME = 'video_deduplicator'

@@ -19,14 +20,20 @@ class VideoDeduplicator(Deduplicator):
of videos between documents.
"""

def __init__(self, *args, **kwargs):
def __init__(self, consider_text: bool = False, *args, **kwargs):
"""
Initialization.
:param consider_text: whether to consider text hash together with video
hash when applying deduplication.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.consider_text = consider_text
self.text_dedup_op = None
if self.consider_text:
self.text_dedup_op = DocumentDeduplicator(**kwargs)

def compute_hash(self, sample, context=False):
# check if it's computed already
@@ -52,6 +59,8 @@ def compute_hash(self, sample, context=False):
md5_hash.update(bytes(packet))

sample[HashKeys.videohash] = md5_hash.hexdigest()
if self.consider_text:
sample = self.text_dedup_op.compute_hash(sample)
return sample

def process(self, dataset, show_num=0):
@@ -70,8 +79,14 @@ def process(self, dataset, show_num=0):
dup_hashes = None
if show_num > 0:
# sample duplicate pairs
hash2ids: Dict[int, Set[int]] = defaultdict(set)
for sid, hash_val in enumerate(dataset[HashKeys.videohash]):
if self.consider_text:
hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
hashes = zip(dataset[HashKeys.videohash],
dataset[HashKeys.hash])
else:
hash2ids: Dict[int, Set[int]] = defaultdict(set)
hashes = dataset[HashKeys.videohash]
for sid, hash_val in enumerate(hashes):
if hash_val:
hash2ids[hash_val].add(sid)
dup_samples = sorted(list(hash2ids.items()),
@@ -82,7 +97,10 @@
][:show_num])

def _filter_dup_helper(sample, hashes):
hash = sample[HashKeys.videohash]
if self.consider_text:
hash = (sample[HashKeys.videohash], sample[HashKeys.hash])
else:
hash = sample[HashKeys.videohash]
if not hash:
return True
if show_num > 0 and hash in dup_hashes \
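Stripped of the dataset plumbing, the keying change on both the image and the video side can be illustrated with plain dicts (a conceptual sketch, not the operator's actual code path; the real fields live under HashKeys constants):

def dedup(samples, consider_text):
    # without consider_text the key is the media hash alone; with it, the key
    # becomes a (media_hash, text_hash) tuple, so identical videos that carry
    # different texts are no longer merged
    seen, kept = set(), []
    for s in samples:
        key = (s['videohash'], s['texthash']) if consider_text else s['videohash']
        if key not in seen:
            seen.add(key)
            kept.append(s)
    return kept

samples = [
    {'id': 0, 'videohash': 'v1', 'texthash': 't1'},
    {'id': 1, 'videohash': 'v1', 'texthash': 't2'},
    {'id': 2, 'videohash': 'v1', 'texthash': 't1'},
]

print([s['id'] for s in dedup(samples, consider_text=False)])  # [0]
print([s['id'] for s in dedup(samples, consider_text=True)])   # [0, 1]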
77 changes: 76 additions & 1 deletion tests/ops/deduplicator/test_image_deduplicator.py
@@ -32,10 +32,12 @@ class ImageDeduplicatorTest(DataJuicerTestCaseBase):
os.symlink(img6_path, img7_path)

def _run_image_deduplicator(self, dataset: Dataset, target_list, op):
key_list = [op.image_key, op.text_key] \
if op.consider_text else [op.image_key]

dataset = dataset.map(op.compute_hash)
dataset, _ = op.process(dataset)
dataset = dataset.select_columns(column_names=[op.image_key])
dataset = dataset.select_columns(column_names=key_list)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

@@ -101,6 +103,50 @@ def test_3(self):
op = ImageDeduplicator()
self._run_image_deduplicator(dataset, tgt_list, op)

def test_3_consider_text(self):

ds_list = [{
'images': [self.img1_path],
'text': '<video> text1'
}, {
'images': [self.img2_path],
'text': '<video> text2'
}, {
'images': [self.img3_path],
'text': '<video> text3'
}, {
'images': [self.img4_path],
'text': '<video> text1'
}, {
'images': [self.img5_path],
'text': '<video> text5'
}, {
'images': [self.img6_path],
'text': '<video> text3'
}, {
'images': [self.img7_path],
'text': '<video> text7'
}]
tgt_list = [{
'images': [self.img1_path],
'text': '<video> text1'
}, {
'images': [self.img2_path],
'text': '<video> text2'
}, {
'images': [self.img3_path],
'text': '<video> text3'
}, {
'images': [self.img5_path],
'text': '<video> text5'
}, {
'images': [self.img7_path],
'text': '<video> text7'
}]
dataset = Dataset.from_list(ds_list)
op = ImageDeduplicator(consider_text=True)
self._run_image_deduplicator(dataset, tgt_list, op)

def test_4(self):

ds_list = [{
@@ -121,6 +167,35 @@ def test_4(self):
op = ImageDeduplicator()
self._run_image_deduplicator(dataset, tgt_list, op)

def test_4_consider_text(self):

ds_list = [{
'images': [self.img1_path, self.img2_path, self.img3_path],
'text': '<image> text1 <image> text2 <image> text3',
}, {
'images': [self.img4_path, self.img5_path, self.img6_path],
'text': '<image> text1 <image> text5 <image> text3',
}, {
'images': [self.img7_path],
'text': '<image> text6',
}, {
'images': [self.img6_path],
'text': '<image> text6',
}]
tgt_list = [{
'images': [self.img1_path, self.img2_path, self.img3_path],
'text': '<image> text1 <image> text2 <image> text3',
}, {
'images': [self.img4_path, self.img5_path, self.img6_path],
'text': '<image> text1 <image> text5 <image> text3',
}, {
'images': [self.img7_path],
'text': '<image> text6',
}]
dataset = Dataset.from_list(ds_list)
op = ImageDeduplicator(consider_text=True)
self._run_image_deduplicator(dataset, tgt_list, op)

def test_5(self):

ds_list = [{