Skip to content

Commit

Permalink
Switch face detection to OpenCV (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege authored Aug 22, 2024
1 parent b154f4d commit f7103fe
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 63 deletions.
3 changes: 3 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ process:
hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate caption if caption_key is null
mem_required: '8GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- image_face_blur_mapper: # blur faces detected in images
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
radius: 2 # radius of blur kernel
- nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library
Expand Down Expand Up @@ -194,6 +195,7 @@ process:
vertical_flip: false # flip frame image vertically (top to bottom).
mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- video_face_blur_mapper: # blur faces detected in videos
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
radius: 2 # radius of blur kernel
- video_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg video filters
Expand Down Expand Up @@ -278,6 +280,7 @@ process:
max_ratio: 3.0 # the max aspect ratio of filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_face_ratio_filter: # filter samples according to the face area ratios in images (r=face_area/image_area). If multiple faces are available, we use the largest one.
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
min_ratio: 0.0 # the min face area ratio of filter range
max_ratio: 0.4 # the max face area ratio of filter range
- image_nsfw_filter: # filter samples according to the nsfw scores of images in them
Expand Down
47 changes: 29 additions & 18 deletions data_juicer/ops/filter/image_face_ratio_filter.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
import os

import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import (load_data_with_context, load_image,
pil_to_opencv)
from data_juicer.utils.mm_utils import (detect_faces, load_data_with_context,
load_image)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..base_op import OPERATORS, UNFORKABLE, Filter
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_face_ratio_filter'

with AvailabilityChecking(['dlib'], OP_NAME):
import dlib
with AvailabilityChecking(['opencv-python'], OP_NAME):
import cv2


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageFaceRatioFilter(Filter):
"""Filter to keep samples with face area ratios within a specific range.
"""

_default_kwargs = {'upsample_num_times': 0}
_default_kwargs = {
'scaleFactor': 1.1,
'minNeighbors': 3,
'minSize': None,
'maxSize': None,
}

def __init__(self,
cv_classifier='',
min_ratio: ClosedUnitInterval = 0.0,
max_ratio: ClosedUnitInterval = 0.4,
any_or_all: str = 'any',
Expand All @@ -33,6 +43,8 @@ def __init__(self,
"""
Initialization method.
:param cv_classifier: OpenCV classifier path for face detection.
By default, we will use 'haarcascade_frontalface_alt.xml'.
:param min_ratio: Min ratio for the largest face area in an image.
:param max_ratio: Max ratio for the largest face area in an image.
:param any_or_all: Keep this sample with 'any' or 'all' strategy of
Expand All @@ -43,6 +55,10 @@ def __init__(self,
:param kwargs: Extra keyword arguments.
"""
super().__init__(*args, **kwargs)

if cv_classifier == '':
cv_classifier = os.path.join(cv2.data.haarcascades,
'haarcascade_frontalface_alt.xml')
self.min_ratio = min_ratio
self.max_ratio = max_ratio

Expand All @@ -56,8 +72,8 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

# Initialize face detector
self.detector = dlib.get_frontal_face_detector()
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def compute_stats(self, sample, context=False):
# check if it's computed already
Expand All @@ -75,25 +91,20 @@ def compute_stats(self, sample, context=False):
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

model = get_model(self.model_key)

# detect faces
face_detections = {}
for key, image in images.items():
img = pil_to_opencv(image)
dets = self.detector(img, **self.extra_kwargs)
face_detections[key] = [[
max(det.left(), 0),
max(det.top(), 0),
min(det.right(), image.width),
min(det.bottom(), image.height)
] for det in dets]
face_detections[key] = detect_faces(image, model,
**self.extra_kwargs)
logger.debug(f'detections: {face_detections}')

# compute face area ratios for each image considering the largest face
face_area_ratios = {}
for key, dets in face_detections.items():
image_area = images[key].width * images[key].height
face_area_ratios[key] = max([(x2 - x1) * (y2 - y1)
for x1, y1, x2, y2 in dets],
face_area_ratios[key] = max([w * h for _, _, w, h in dets],
default=0.0) / image_area
logger.debug(f'ratios: {face_area_ratios}')

Expand Down
50 changes: 31 additions & 19 deletions data_juicer/ops/mapper/image_face_blur_mapper.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,49 @@
import os

from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.mm_utils import (load_data_with_context, load_image,
pil_to_opencv)
from data_juicer.utils.mm_utils import (detect_faces, load_data_with_context,
load_image)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_face_blur_mapper'

with AvailabilityChecking(['dlib', 'Pillow'], OP_NAME):
import dlib
with AvailabilityChecking(['opencv-python', 'Pillow'], OP_NAME):
import cv2
from PIL import ImageFilter


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageFaceBlurMapper(Mapper):
"""Mapper to blur faces detected in images.
"""

_default_kwargs = {'upsample_num_times': 0}
_default_kwargs = {
'scaleFactor': 1.1,
'minNeighbors': 3,
'minSize': None,
'maxSize': None,
}

def __init__(self,
cv_classifier='',
blur_type: str = 'gaussian',
radius: float = 2,
*args,
**kwargs):
"""
Initialization method.
:param cv_classifier: OpenCV classifier path for face detection.
By default, we will use 'haarcascade_frontalface_alt.xml'.
:param blur_type: Type of blur kernel, including
['mean', 'box', 'gaussian'].
:param radius: Radius of blur kernel.
Expand All @@ -41,6 +53,9 @@ def __init__(self,
super().__init__(*args, **kwargs)
self._init_parameters = self.remove_extra_parameters(locals())

if cv_classifier == '':
cv_classifier = os.path.join(cv2.data.haarcascades,
'haarcascade_frontalface_alt.xml')
if blur_type not in ['mean', 'box', 'gaussian']:
raise ValueError(
f'Blur_type [{blur_type}] is not supported. '
Expand All @@ -63,8 +78,8 @@ def __init__(self,
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

# Initialize face detector
self.detector = dlib.get_frontal_face_detector()
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def process(self, sample, context=False):
# there is no image in this sample
Expand All @@ -80,17 +95,13 @@ def process(self, sample, context=False):
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

model = get_model(self.model_key)

# detect faces
face_detections = {}
for key, image in images.items():
img = pil_to_opencv(image)
dets = self.detector(img, **self.extra_kwargs)
face_detections[key] = [[
max(det.left(), 0),
max(det.top(), 0),
min(det.right(), image.width),
min(det.bottom(), image.height)
] for det in dets]
face_detections[key] = detect_faces(image, model,
**self.extra_kwargs)
logger.debug(f'detections: {face_detections}')

# blur face regions
Expand All @@ -100,9 +111,10 @@ def process(self, sample, context=False):
# only blur when detected face
if len(dets) > 0:
blured_image = image.copy()
for det in dets:
blured_roi = image.crop(det).filter(self.blur)
blured_image.paste(blured_roi, det)
for (x, y, w, h) in dets:
box = (x, y, x + w, y + h)
blured_roi = image.crop(box).filter(self.blur)
blured_image.paste(blured_roi, box)
blured_image_key = transfer_filename(key, OP_NAME,
**self._init_parameters)
blured_image.save(blured_image_key)
Expand Down
61 changes: 37 additions & 24 deletions data_juicer/ops/mapper/video_face_blur_mapper.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,50 @@
import os

import av

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
load_video, pil_to_opencv,
from data_juicer.utils.mm_utils import (close_video, detect_faces,
load_data_with_context, load_video,
process_each_frame)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = 'video_face_blur_mapper'

with AvailabilityChecking(['dlib', 'Pillow'], OP_NAME):
import dlib
with AvailabilityChecking(['opencv-python', 'Pillow'], OP_NAME):
import cv2
from PIL import ImageFilter


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoFaceBlurMapper(Mapper):
"""Mapper to blur faces detected in videos.
"""

_default_kwargs = {'upsample_num_times': 0}
_default_kwargs = {
'scaleFactor': 1.1,
'minNeighbors': 3,
'minSize': None,
'maxSize': None,
}

def __init__(self,
cv_classifier='',
blur_type: str = 'gaussian',
radius: float = 2,
*args,
**kwargs):
"""
Initialization method.
:param cv_classifier: OpenCV classifier path for face detection.
By default, we will use 'haarcascade_frontalface_alt.xml'.
:param blur_type: Type of blur kernel, including
['mean', 'box', 'gaussian'].
:param radius: Radius of blur kernel.
Expand All @@ -42,6 +54,9 @@ def __init__(self,
super().__init__(*args, **kwargs)
self._init_parameters = self.remove_extra_parameters(locals())

if cv_classifier == '':
cv_classifier = os.path.join(cv2.data.haarcascades,
'haarcascade_frontalface_alt.xml')
if blur_type not in ['mean', 'box', 'gaussian']:
raise ValueError(
f'Blur_type [{blur_type}] is not supported. '
Expand All @@ -64,8 +79,8 @@ def __init__(self,
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

# Initialize face detector
self.detector = dlib.get_frontal_face_detector()
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def process(self, sample, context=False):
# there is no video in this sample
Expand All @@ -80,6 +95,19 @@ def process(self, sample, context=False):
sample, videos = load_data_with_context(sample, context,
loaded_video_keys, load_video)

model = get_model(self.model_key)

def _blur_func(frame):
    """
    Blur every detected face in a single video frame.

    :param frame: frame object providing ``to_image()``.
    :return: the input frame itself when no face is detected,
        otherwise a new ``av.VideoFrame`` built from the blurred image.
    """
    pil_frame = frame.to_image()
    face_boxes = detect_faces(pil_frame, model, **self.extra_kwargs)

    # nothing detected: hand the original frame back untouched
    if len(face_boxes) == 0:
        return frame

    for (left, top, width, height) in face_boxes:
        # detections are (x, y, w, h); crop/paste take corner coordinates
        region = (left, top, left + width, top + height)
        blurred_patch = pil_frame.crop(region).filter(self.blur)
        pil_frame.paste(blurred_patch, region)

    return av.VideoFrame.from_image(pil_frame)

processed_video_keys = {}
for video_key in loaded_video_keys:
# skip duplicate
Expand All @@ -90,7 +118,7 @@ def process(self, sample, context=False):
blured_video_key = transfer_filename(video_key, OP_NAME,
**self._init_parameters)
output_video_key = process_each_frame(video, blured_video_key,
self._blur_face)
_blur_func)
processed_video_keys[video_key] = output_video_key

if not context:
Expand All @@ -106,18 +134,3 @@ def process(self, sample, context=False):
processed_video_keys[key] for key in loaded_video_keys
]
return sample

def _blur_face(self, frame):
image = frame.to_image()
img = pil_to_opencv(image)
dets = self.detector(img, **self.extra_kwargs)
if len(dets) > 0:
for det in dets:
x1 = max(det.left(), 0)
y1 = max(det.top(), 0)
x2 = min(det.right(), image.width)
y2 = min(det.bottom(), image.height)
blured_roi = image.crop((x1, y1, x2, y2)).filter(self.blur)
image.paste(blured_roi, (x1, y1, x2, y2))
frame = av.VideoFrame.from_image(image)
return frame
16 changes: 16 additions & 0 deletions data_juicer/utils/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,22 @@ def pil_to_opencv(pil_image):
return opencv_image


def detect_faces(image, detector, **extra_kwargs):
    """
    Detect faces in an image and return boxes clipped to the image bounds.

    :param image: image object accepted by ``pil_to_opencv``; its
        ``width`` and ``height`` attributes are used for clipping.
    :param detector: object exposing ``detectMultiScale`` (e.g. an OpenCV
        cascade classifier).
    :param extra_kwargs: keyword arguments forwarded unchanged to
        ``detector.detectMultiScale``.
    :return: list of ``[x, y, w, h]`` boxes (plain ints), one per
        detection, in the order returned by the detector.
    """
    import cv2

    img = pil_to_opencv(image)
    # the converted array is treated as BGR here
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    dets = detector.detectMultiScale(gray, **extra_kwargs)

    rectified_dets = []
    for (x, y, w, h) in dets:
        # Clip on corner coordinates so that a box starting outside the
        # image keeps its true right/bottom edge: clamping only the
        # origin would leave the width/height too large by the clipped
        # amount.
        x, y, w, h = int(x), int(y), int(w), int(h)
        x1 = max(x, 0)
        y1 = max(y, 0)
        x2 = min(x + w, image.width)
        y2 = min(y + h, image.height)
        # a box lying entirely off-image collapses to zero size instead of
        # producing a negative width/height
        rectified_dets.append([x1, y1, max(x2 - x1, 0), max(y2 - y1, 0)])
    return rectified_dets


def get_file_size(path):
import os
return os.path.getsize(path)
Expand Down
Loading

0 comments on commit f7103fe

Please sign in to comment.