Optimize fa #2

Open
wants to merge 13 commits into master
39 changes: 25 additions & 14 deletions face_alignment/api.py
@@ -5,7 +5,6 @@

from .utils import *


class LandmarksType(IntEnum):
"""Enum class defining the type of landmarks to detect.

@@ -27,8 +26,8 @@ class NetworkSize(IntEnum):


class FaceAlignment:
-    def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path, network_size=NetworkSize.LARGE,
-                 device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
+    def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path=None, network_size=NetworkSize.LARGE,
+                 device='cuda', flip_input=False, face_detector=None, face_detector_kwargs=None, verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
@@ -51,10 +50,13 @@ def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path,
torch.backends.cudnn.benchmark = True

# Get the face detector
-        face_detector_module = __import__('face_alignment.detection.' + face_detector,
-                                          globals(), locals(), [face_detector], 0)
-        face_detector_kwargs = face_detector_kwargs or {}
-        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose, **face_detector_kwargs)
+        if face_detector:
+            face_detector_module = __import__('face_alignment.detection.' + face_detector,
+                                              globals(), locals(), [face_detector], 0)
+            face_detector_kwargs = face_detector_kwargs or {}
+            self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose, **face_detector_kwargs)
+        else:
+            self.face_detector = None

# Initialise the face alignment networks
if landmarks_type == LandmarksType._2D:
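
Side note on the new optional-detector path: with face_detector=None, callers must supply their own boxes to get_landmarks_from_image. A minimal usage sketch (hypothetical file paths and box values, assuming the defaults introduced in this diff):

    # Skip detection entirely and hand get_landmarks_from_image precomputed boxes.
    import face_alignment

    fa = face_alignment.FaceAlignment(
        face_alignment.LandmarksType._2D,
        face_align_model_path='models/2DFAN4.pth',  # hypothetical checkpoint path
        face_detector=None)                         # no detector module is imported

    boxes = [(110, 92, 310, 292)]  # (x1, y1, x2, y2) from an external detector
    landmarks = fa.get_landmarks_from_image('photo.jpg', detected_faces=boxes)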
@@ -116,7 +118,10 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
image = get_image(image_or_path)

if detected_faces is None:
-            detected_faces = self.face_detector.detect_from_image(image.copy())
+            try:
+                detected_faces = self.face_detector.detect_from_image(image.copy())
+            except AttributeError:
+                raise ValueError("A list of bounding boxes or a face_detector method is needed.")

if len(detected_faces) == 0:
warnings.warn("No faces were detected.")
@@ -127,23 +132,28 @@

landmarks = []
landmarks_scores = []
        for i, d in enumerate(detected_faces):
center = torch.tensor(
[d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
-            center[1] = center[1] - (d[3] - d[1]) * 0.12
-            scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
+            if self.face_detector:
+                center[1] = center[1] - (d[3] - d[1]) * 0.12
+                scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
+            else:
+                scale = (d[2] - d[0]) / 200
Review comment: can we add a comment explaining why 200? Like: "everywhere the scale is multiplied by 200, we don't know why, but if we divide it here by 200 it works."
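
On the reviewer's question: my reading of utils.py (not confirmed by the PR author) is that transform_np sizes the crop window as h = 200.0 * scale, so dividing the box width by 200 here makes the two 200s cancel and the crop window track the detected box:

    # Sketch of why the 200s cancel (assuming transform_np uses h = 200.0 * scale):
    d = (110, 92, 310, 292)      # hypothetical box (x1, y1, x2, y2)
    scale = (d[2] - d[0]) / 200  # the fallback scale chosen above
    h = 200.0 * scale            # crop window size computed in transform_np
    assert h == d[2] - d[0]      # crop window width equals the box width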


inp = crop(image, center, scale)
inp = torch.from_numpy(inp.transpose(
(2, 0, 1))).float()

inp = inp.to(self.device)
inp.div_(255.0).unsqueeze_(0)

+            torch._C._set_graph_executor_optimize(False)
out = self.face_alignment_net(inp).detach()
if self.flip_input:
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
out = out.cpu().numpy()
+            torch._C._set_graph_executor_optimize(True)

pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
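
One nit on the torch._C._set_graph_executor_optimize toggle above: it is a process-wide flag, and if the forward pass raises, it is left disabled. A try/finally wrapper would make it exception-safe; a sketch, assuming the getter torch._C._get_graph_executor_optimize() is available in the targeted PyTorch version:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def graph_executor_optimize(enabled):
        # Temporarily set the TorchScript graph-executor optimization flag,
        # restoring the previous value even if the wrapped code raises.
        previous = torch._C._get_graph_executor_optimize()
        torch._C._set_graph_executor_optimize(enabled)
        try:
            yield
        finally:
            torch._C._set_graph_executor_optimize(previous)

The forward/flip block above would then sit inside "with graph_executor_optimize(False):".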
@@ -156,18 +166,19 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
if pts[i, 0] > 0 and pts[i, 1] > 0:
heatmaps[i] = draw_gaussian(
heatmaps[i], pts[i], 2)

heatmaps = torch.from_numpy(
heatmaps).unsqueeze_(0)

heatmaps = heatmaps.to(self.device)
depth_pred = self.depth_prediciton_net(
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
pts_img = torch.cat(
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

landmarks.append(pts_img.numpy())
landmarks_scores.append(scores)

if not return_bboxes:
detected_faces = None
if not return_landmark_score:
4 changes: 1 addition & 3 deletions face_alignment/utils.py
@@ -145,7 +145,6 @@ def crop(image, center, scale, resolution=256.0):
return newImg


-@jit(nopython=True)
def transform_np(point, center, scale, resolution, invert=False):
"""Generate and affine transformation matrix.

@@ -203,7 +202,6 @@ def get_preds_fromhm(hm, center=None, scale=None):
return preds, preds_orig, scores


-@jit(nopython=True)
def _get_preds_fromhm(hm, idx, center=None, scale=None):
"""Obtain (x,y) coordinates given a set of N heatmaps and the
corresponding locations of the maximums. If the center
@@ -234,7 +232,7 @@ def _get_preds_fromhm(hm, idx, center=None, scale=None):
preds[i, j] += np.sign(diff) * 0.25

preds -= 0.5

preds_orig = np.zeros_like(preds)
if center is not None and scale is not None:
for i in range(B):
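
A last observation on utils.py: dropping the numba @jit(nopython=True) decorators leaves transform_np and _get_preds_fromhm running as plain NumPy, which simplifies dependencies but may slow post-processing. A quick timing check would quantify the cost; a minimal sketch with hypothetical heatmap shapes:

    import timeit

    import numpy as np
    from face_alignment.utils import get_preds_fromhm

    hm = np.random.rand(1, 68, 64, 64).astype(np.float32)  # hypothetical heatmap batch
    center = np.array([128.0, 128.0])
    scale = 1.0
    print(timeit.timeit(lambda: get_preds_fromhm(hm, center, scale), number=100))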