Skip to content

Commit

Permalink
modify image_dedup
Browse files: browse the repository at this point in the history
  • Loading branch information
chenhesen committed Nov 15, 2023
1 parent a11c68f commit 5580fc9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions data_juicer/ops/deduplicator/image_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from data_juicer.utils.constant import Fields, HashKeys
from data_juicer.utils.mm_utils import load_image

from ..base_op import OPERATORS, Filter
from ..base_op import OPERATORS, Deduplicator
from ..op_fusion import LOADED_IMAGES

HASH_METHOD = {
Expand All @@ -20,7 +20,7 @@

@OPERATORS.register_module('image_deduplicator')
@LOADED_IMAGES.register_module('image_deduplicator')
class ImageDeduplicator(Filter):
class ImageDeduplicator(Deduplicator):
"""
Deduplicator to deduplicate samples at document-level using exact matching
of images between documents.
Expand All @@ -38,7 +38,7 @@ def __init__(self, method: str = 'phash', *args, **kwargs):
if method not in HASH_METHOD.keys():
raise ValueError(f'Keep strategy [{method}] is not supported. '
f'Can only be one of {HASH_METHOD.keys()}.')
self.phasher = HASH_METHOD[method]
self.hasher = HASH_METHOD[method]

def compute_hash(self, sample, context=False):
# check if it's computed already
Expand Down Expand Up @@ -69,7 +69,7 @@ def compute_hash(self, sample, context=False):

# compute hash
for key in images:
sample[HashKeys.imagehash] += self.phasher.encode_image(
sample[HashKeys.imagehash] += self.hasher.encode_image(
image_array=np.array(images[key]))
return sample

Expand Down
4 changes: 4 additions & 0 deletions tests/ops/deduplicator/test_image_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ class ImageDeduplicatorTest(unittest.TestCase):
img1_path = os.path.join(data_path, 'img1.png')
img2_path = os.path.join(data_path, 'img2.jpg')
img3_path = os.path.join(data_path, 'img3.jpg')
# img4.png is a duplicate sample of img1.png
img4_path = os.path.join(data_path, 'img4.png')
# img5.jpg is a duplicate sample of img2.jpg
img5_path = os.path.join(data_path, 'img5.jpg')
# img6.jpg is a duplicate sample of img3.jpg
img6_path = os.path.join(data_path, 'img6.jpg')
# img7.jpg is a duplicate sample of img6.jpg
img7_path = os.path.join(data_path, 'img7.jpg')


Expand Down

0 comments on commit 5580fc9

Please sign in to comment.