diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 708db5c22..37c541922 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -120,6 +120,10 @@ process: min_ratio: 0.333 # the min aspect ratio of filter range max_ratio: 3.0 # the max aspect ratio of filter range any_or_all: any # keep this sample when any/all images meet the filter condition + - image_size_filter: # filter samples according to the size of images (in bytes) within them + min_size: "0" # the min size of filter range + max_size: "1TB" # the max size of filter range + any_or_all: any # keep this sample when any/all images meet the filter condition - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value lang: en # keep text in what language min_score: 0.8 # the min language scores to filter text diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index c9332eea0..2d81c28de 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -1,8 +1,8 @@ from . import (alphanumeric_filter, average_line_length_filter, character_repetition_filter, flagged_words_filter, - image_aspect_ratio_filter, language_id_score_filter, - maximum_line_length_filter, perplexity_filter, - special_characters_filter, specified_field_filter, - specified_numeric_field_filter, stopwords_filter, suffix_filter, - text_length_filter, token_num_filter, word_num_filter, - word_repetition_filter) + image_aspect_ratio_filter, image_size_filter, + language_id_score_filter, maximum_line_length_filter, + perplexity_filter, special_characters_filter, + specified_field_filter, specified_numeric_field_filter, + stopwords_filter, suffix_filter, text_length_filter, + token_num_filter, word_num_filter, word_repetition_filter) diff --git a/data_juicer/ops/filter/image_size_filter.py b/data_juicer/ops/filter/image_size_filter.py new file mode 100644 index 000000000..254bf6141 --- /dev/null +++ b/data_juicer/ops/filter/image_size_filter.py @@ -0,0 +1,74 @@ +import numpy as np + +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import get_image_size, size_to_bytes + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('image_size_filter') +class ImageSizeFilter(Filter): + """Keep data samples whose image size (in bytes/kb/MB/...) within a + specific range. + """ + + def __init__(self, + min_size: str = '0', + max_size: str = '1TB', + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_size: The min image size to keep samples. set to be "0" by + default for no size constraint + :param max_size: The max image size to keep samples. set to be + "1Tb" by default, an approximate for un-limited case + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_size = size_to_bytes(min_size) + self.max_size = size_to_bytes(max_size) + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.image_sizes in sample[Fields.stats]: + return sample + + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + sample[Fields.stats][StatsKeys.image_sizes] = np.array( + [], dtype=np.float64) + return sample + + # for size calculation, no need to load images into memory + sample[Fields.stats][StatsKeys.image_sizes] = [ + get_image_size(img_path) for img_path in sample[self.image_key] + ] + + return sample + + def process(self, sample): + image_sizes = sample[Fields.stats][StatsKeys.image_sizes] + keep_bools = np.array([ + self.min_size <= image_size <= self.max_size + for image_size in image_sizes + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index b2f3a362f..216be5ae6 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -28,6 +28,7 @@ class StatsKeys(object): # image aspect_ratios = 'aspect_ratios' + image_sizes = 'image_sizes' class HashKeys(object): diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 6ed0931e0..ea6b2063f 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -21,3 +21,47 @@ def load_image(path): img_feature = Image() img = img_feature.decode_example(img_feature.encode_example(path)) return img + + +def get_image_size(path): + import os + return os.path.getsize(path) + + +def size_to_bytes(size): + alphabets_list = [char for char in size if char.isalpha()] + numbers_list = [char for char in size if char.isdigit()] + + if len(numbers_list) == 0: + raise ValueError(f'Your input `size` does not contain numbers: {size}') + + size_numbers = int(float(''.join(numbers_list))) + + if len(alphabets_list) == 0: + # by default, if users do not specify the units, the number will be + # regarded as in bytes + return size_numbers + + suffix = ''.join(alphabets_list).lower() + + if suffix == 'kb' or suffix == 'kib': + return size_numbers << 10 + elif suffix == 'mb' or suffix == 'mib': + return size_numbers << 20 + elif suffix == 'gb' or suffix == 'gib': + return size_numbers << 30 + elif suffix == 'tb' or suffix == 'tib': + return size_numbers << 40 + elif suffix == 'pb' or suffix == 'pib': + return size_numbers << 50 + elif suffix == 'eb' or suffix == 'eib': + return size_numbers << 60 + elif suffix == 'zb' or suffix == 'zib': + return size_numbers << 70 + elif suffix == 'yb' or suffix == 'yib': + return size_numbers << 80 + else: + raise ValueError(f'You specified unidentifiable unit: {suffix}, ' + f'expected in [KB, MB, GB, TB, PB, EB, ZB, YB, ' + f'KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], ' + f'(case insensitive, counted by *Bytes*).') diff --git a/demos/overview_scan/app.py b/demos/overview_scan/app.py index 378b8f502..5a64b9bae 100644 --- a/demos/overview_scan/app.py +++ b/demos/overview_scan/app.py @@ -89,7 +89,7 @@ |-----------------------------------|:------:|-------------------------------------------------| | Formatter | 7 | Discovers, loads, and canonicalizes source data | | Mapper | 21 | Edits and transforms samples | -| Filter | 17 | Filters out low-quality samples | +| Filter | 18 | Filters out low-quality samples | | Deduplicator | 3 | Detects and removes duplicate samples | | Selector | 2 | Selects top samples based on ranking | ''' @@ -142,6 +142,7 @@ | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range | +| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range | | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | diff --git a/docs/Operators.md b/docs/Operators.md index 78abeb495..4f72b4c34 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 21 | Edits and transforms samples | -| [ Filter ]( #filter ) | 17 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 18 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 3 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -77,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili | character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | | flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | | image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range | +| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range | | language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | | maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | | perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index cf3421d94..0be43ae2d 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -73,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 | | flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 | | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | +| image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | | language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 | | maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 | | perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 | diff --git a/tests/ops/filter/test_image_size_filter.py b/tests/ops/filter/test_image_size_filter.py new file mode 100644 index 000000000..b9dc8fa01 --- /dev/null +++ b/tests/ops/filter/test_image_size_filter.py @@ -0,0 +1,118 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.image_size_filter import ImageSizeFilter +from data_juicer.utils.constant import Fields + + +class ImageSizeFilterTest(unittest.TestCase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + '..', 'data') + img1_path = os.path.join(data_path, 'img1.png') + img2_path = os.path.join(data_path, 'img2.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + + def _run_image_size_filter(self, + dataset: Dataset, target_list, + op): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=[op.image_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_min_max(self): + + ds_list = [{ + 'images': [self.img1_path] # 171KB + }, { + 'images': [self.img2_path] # 189KB + }, { + 'images': [self.img3_path] # 114KB + }] + tgt_list = [{ + 'images': [self.img1_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageSizeFilter(min_size="120kb", max_size="180KB") + self._run_image_size_filter(dataset, tgt_list, op) + + def test_min(self): + + ds_list = [{ + 'images': [self.img1_path] # 171KB + }, { + 'images': [self.img2_path] # 189KB + }, { + 'images': [self.img3_path] # 114KB + }] + tgt_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageSizeFilter(min_size="120kib") + self._run_image_size_filter(dataset, tgt_list, op) + + def test_max(self): + + ds_list = [{ + 'images': [self.img1_path] # 171KB + }, { + 'images': [self.img2_path] # 189KB + }, { + 'images': [self.img3_path] # 114KB + }] + tgt_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageSizeFilter(max_size="180KiB") + self._run_image_size_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageSizeFilter(min_size="120kb", max_size="180KB", + any_or_all='any') + self._run_image_size_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = ImageSizeFilter(min_size="120kb", max_size="180KB", + any_or_all='all') + self._run_image_size_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main()