Skip to content

Commit

Permalink
Feature/image shape filter (#74)
Browse files Browse the repository at this point in the history
* + Add new OP: image_shape_filter

* * avoid W605 warning

* * pyarrow should be <= 13.0.0 to avoid useless FutureWarning
  • Loading branch information
HYLcool authored Nov 15, 2023
1 parent 50f1aca commit de984f3
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 6 deletions.
8 changes: 7 additions & 1 deletion configs/config_all.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Process config example including:
# - all global arguments
# - all ops and their default arguments
# - all ops and their arguments

# global parameters
project_name: 'all' # project name for distinguish your configs
Expand Down Expand Up @@ -126,6 +126,12 @@ process:
min_ratio: 0.333 # the min aspect ratio of filter range
max_ratio: 3.0 # the max aspect ratio of filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_shape_filter: # filter samples according to the widths and heights of images in them
min_width: 200 # the min width of width filter range
max_width: 5000 # the max width of width filter range
min_height: 200 # the min height of height filter range
max_height: 5000 # the max height of height filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_size_filter: # filter samples according to the size of images (in bytes) within them
min_size: "0" # the min size of filter range
max_size: "1TB" # the max size of filter range
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, clip_similarity_filter,
flagged_words_filter, image_aspect_ratio_filter,
image_size_filter, language_id_score_filter,
image_shape_filter, image_size_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
Expand Down
107 changes: 107 additions & 0 deletions data_juicer/ops/filter/image_shape_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import sys

import numpy as np
from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_image

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES


@OPERATORS.register_module('image_shape_filter')
@LOADED_IMAGES.register_module('image_shape_filter')
class ImageShapeFilter(Filter):
"""Filter to keep samples with image shape (w, h) within specific ranges.
"""

def __init__(self,
min_width: PositiveInt = 1,
max_width: PositiveInt = sys.maxsize,
min_height: PositiveInt = 1,
max_height: PositiveInt = sys.maxsize,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.
:param min_width: The min width to keep samples.
:param max_width: The max width to keep samples.
:param min_height: The min height to keep samples.
:param max_height: The max height to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_width = min_width
self.max_width = max_width
self.min_height = min_height
self.max_height = max_height
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_width in sample[Fields.stats] \
and StatsKeys.image_height in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.image_width] = np.array(
[], dtype=np.int64)
sample[Fields.stats][StatsKeys.image_height] = np.array(
[], dtype=np.int64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
images = {}
for loaded_image_key in loaded_image_keys:
if context and loaded_image_key in sample[Fields.context]:
# load from context
images[loaded_image_key] = sample[
Fields.context][loaded_image_key]
else:
if loaded_image_key not in images:
# avoid load the same images
image = load_image(loaded_image_key)
images[loaded_image_key] = image
if context:
# store the image data into context
sample[Fields.context][loaded_image_key] = image

# get width and height for each image
whs = {key: (images[key].width, images[key].height) for key in images}
sample[Fields.stats][StatsKeys.image_width] = [
whs[key][0] for key in loaded_image_keys
]
sample[Fields.stats][StatsKeys.image_height] = [
whs[key][1] for key in loaded_image_keys
]
return sample

def process(self, sample):
ws = sample[Fields.stats][StatsKeys.image_width]
hs = sample[Fields.stats][StatsKeys.image_height]
if len(ws) <= 0:
return True
keep_bools = np.array([
self.min_width <= w <= self.max_width
and self.min_height <= h <= self.max_height
for w, h in zip(ws, hs)
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
2 changes: 1 addition & 1 deletion data_juicer/utils/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(self, compressor_format: str = 'zstd'):
self.compressor_extension = '.' + compressor_format
self.compress_manager = CompressManager(
compressor_format=compressor_format)
self.pattern = re.compile('_\d{5}_of_') # noqa W605
self.pattern = re.compile(r'_\d{5}_of_')

def _get_raw_filename(self, filename: Union[Path, str]):
"""
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class StatsKeys(object):

# image
aspect_ratios = 'aspect_ratios'
image_width = 'image_width'
image_height = 'image_height'
image_sizes = 'image_sizes'

# multimodal
Expand Down
3 changes: 2 additions & 1 deletion demos/overview_scan/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
|-----------------------------------|:------:|-------------------------------------------------|
| Formatter | 7 | Discovers, loads, and canonicalizes source data |
| Mapper | 21 | Edits and transforms samples |
| Filter | 19 | Filters out low-quality samples |
| Filter | 20 | Filters out low-quality samples |
| Deduplicator | 4 | Detects and removes duplicate samples |
| Selector | 2 | Selects top samples based on ranking |
'''
Expand Down Expand Up @@ -143,6 +143,7 @@
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
| [ Filter ]( #filter ) | 19 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 20 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |

Expand Down Expand Up @@ -80,6 +80,7 @@ All the specific operators are listed below, each featured with several capabili
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 19 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 20 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -77,6 +77,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 |
| image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 |
| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
| maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 |
Expand Down
1 change: 1 addition & 0 deletions environments/minimal_requires.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
fsspec==2023.3.0
pyarrow<=13.0.0
pandas==2.0.0
datasets==2.11.0
loguru
Expand Down
127 changes: 127 additions & 0 deletions tests/ops/filter/test_image_shape_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import os
import unittest
import numpy as np
import PIL.Image

from datasets import Dataset, Image

from data_juicer.ops.filter.image_shape_filter import ImageShapeFilter
from data_juicer.utils.constant import Fields


class ImageShapeFilterTest(unittest.TestCase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'..', 'data')
img1_path = os.path.join(data_path, 'img1.png')
img2_path = os.path.join(data_path, 'img2.jpg')
img3_path = os.path.join(data_path, 'img3.jpg')

def _run_image_shape_filter(self,
dataset: Dataset,
target_list,
op):
if Fields.stats not in dataset.features:
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
dataset = dataset.select_columns(column_names=[op.image_key])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

def test_filter1(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img2_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400)
self._run_image_shape_filter(dataset, tgt_list, op)

def test_filter2(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path]
}, {
'images': [self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(max_width=500,
max_height=500)
self._run_image_shape_filter(dataset, tgt_list, op)

def test_filter3(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter()
self._run_image_shape_filter(dataset, tgt_list, op)

def test_any(self):

ds_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}, {
'images': [self.img1_path, self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400,
any_or_all='any')
self._run_image_shape_filter(dataset, tgt_list, op)

def test_all(self):

ds_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}, {
'images': [self.img1_path, self.img3_path]
}]
tgt_list = []
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400,
any_or_all='all')
self._run_image_shape_filter(dataset, tgt_list, op)


if __name__ == '__main__':
unittest.main()

0 comments on commit de984f3

Please sign in to comment.