Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
refine body_uv_rcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnqczhang committed Mar 11, 2019
1 parent 8a7f239 commit 723ab20
Show file tree
Hide file tree
Showing 13 changed files with 575 additions and 592 deletions.
3 changes: 3 additions & 0 deletions detectron/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,9 @@
# Number of patches in the dataset
__C.BODY_UV_RCNN.NUM_PATCHES = -1

# Number of semantic parts used to sample annotation points
__C.BODY_UV_RCNN.NUM_SEMANTIC_PARTS = 14

# Number of stacked Conv layers in body UV head
__C.BODY_UV_RCNN.NUM_STACKED_CONVS = 8
# Dimension of the hidden representation output by the body UV head
Expand Down
6 changes: 3 additions & 3 deletions detectron/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,16 +948,16 @@ def im_detect_body_uv(model, im_scale, boxes):
# Removed squeeze calls due to singleton dimension issues
CurAnnIndex = np.argmax(CurAnnIndex, axis=0)
CurIndex_UV = np.argmax(CurIndex_UV, axis=0)
CurIndex_UV = CurIndex_UV * (CurAnnIndex>0).astype(np.float32)
CurIndex_UV = CurIndex_UV * (CurAnnIndex > 0).astype(np.float32)

output = np.zeros([3, int(by), int(bx)], dtype=np.float32)
output[0] = CurIndex_UV

for part_id in range(1, K):
CurrentU = CurU_uv[part_id]
CurrentV = CurV_uv[part_id]
output[1, CurIndex_UV==part_id] = CurrentU[CurIndex_UV==part_id]
output[2, CurIndex_UV==part_id] = CurrentV[CurIndex_UV==part_id]
output[1, CurIndex_UV == part_id] = CurrentU[CurIndex_UV == part_id]
output[2, CurIndex_UV == part_id] = CurrentV[CurIndex_UV == part_id]
outputs.append(output)

num_classes = cfg.MODEL.NUM_CLASSES
Expand Down
8 changes: 4 additions & 4 deletions detectron/datasets/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _prep_roidb_entry(self, entry):
(0, 3, self.num_keypoints), dtype=np.int32
)
if cfg.MODEL.BODY_UV_ON:
entry['ignore_UV_body'] = np.empty((0), dtype=np.bool)
entry['ignore_UV_body'] = np.empty((0), dtype=np.bool)
# entry['Box_image_links_body'] = []
# Remove unwanted fields that come from the json file (if they exist)
for k in ['date_captured', 'url', 'license', 'file_name']:
Expand Down Expand Up @@ -200,7 +200,7 @@ def _add_gt_annotations(self, entry):
valid_objs.append(obj)
valid_segms.append(obj['segmentation'])
###
if 'dp_x' in obj.keys():
if 'dp_x' in obj:
valid_dp_x.append(obj['dp_x'])
valid_dp_y.append(obj['dp_y'])
valid_dp_I.append(obj['dp_I'])
Expand All @@ -216,7 +216,7 @@ def _add_gt_annotations(self, entry):
valid_dp_masks.append([])
###
num_valid_objs = len(valid_objs)
##

boxes = np.zeros((num_valid_objs, 4), dtype=entry['boxes'].dtype)
gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
gt_overlaps = np.zeros(
Expand All @@ -234,7 +234,7 @@ def _add_gt_annotations(self, entry):
dtype=entry['gt_keypoints'].dtype
)
if cfg.MODEL.BODY_UV_ON:
ignore_UV_body = np.zeros((num_valid_objs))
ignore_UV_body = np.zeros((num_valid_objs), dtype=entry['ignore_UV_body'].dtype)
#Box_image_body = [None]*num_valid_objs

im_has_visible_keypoints = False
Expand Down
2 changes: 1 addition & 1 deletion detectron/datasets/roidb.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def is_valid(entry):
if cfg.MODEL.BODY_UV_ON and cfg.BODY_UV_RCNN.BODY_UV_IMS:
# Exclude images with no body uv
valid = valid and entry['has_body_uv']
return valid
return valid

num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
Expand Down
260 changes: 128 additions & 132 deletions detectron/modeling/body_uv_rcnn_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,150 +3,146 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##############################################################################

"""Various network "heads" for dense human pose estimation in DensePose.
The design is as follows:
... -> RoI ----\ /-> mask output -> cls loss
-> RoIFeatureXform -> body UV head -> patch output -> cls loss
... -> Feature / \-> UV output -> reg loss
Map
The body UV head produces a feature representation of the RoI for the purpose
of dense semantic mask prediction, body surface patch prediction and body UV
coordinates regression. The body UV output module converts the feature
representation into heatmaps for dense mask, patch index and UV coordinates.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core

from detectron.core.config import cfg

from detectron.utils.c2 import const_fill
import detectron.modeling.ResNet as ResNet
import detectron.utils.blob as blob_utils

# ---------------------------------------------------------------------------- #
# Body UV heads
# Body UV outputs and losses
# ---------------------------------------------------------------------------- #

def add_body_uv_outputs(model, blob_in, dim, pref=''):
####
model.ConvTranspose(blob_in, 'AnnIndex_lowres'+pref, dim, 15,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.}))
####
model.ConvTranspose(blob_in, 'Index_UV_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.}))
####
model.ConvTranspose(
blob_in, 'U_lowres'+pref, dim, (cfg.BODY_UV_RCNN.NUM_PATCHES+1),
cfg.BODY_UV_RCNN.DECONV_KERNEL,
pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1),
stride=2,
weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.}))
#####
model.ConvTranspose(
blob_in, 'V_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1,
def add_body_uv_outputs(model, blob_in, dim):
"""Add DensePose body UV specific outputs: heatmaps of dense mask, patch index
and patch-specific UV coordinates. All dense masks are mapped to labels in
[0, ... S] for S semantically meaningful body parts.
"""
# Apply ConvTranspose to the feature representation; results in 2x upsampling
for name in ['AnnIndex', 'Index_UV', 'U', 'V']:
if name == 'AnnIndex':
dim_out = cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS + 1
else:
dim_out = cfg.BODY_UV_RCNN.NUM_PATCHES + 1
model.ConvTranspose(
blob_in,
name + '_lowres',
dim,
dim_out,
cfg.BODY_UV_RCNN.DECONV_KERNEL,
pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1),
stride=2,
weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.}))
####
blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
blob_Index = model.BilinearInterpolation('Index_UV_lowres'+pref, 'Index_UV'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
blob_U = model.BilinearInterpolation('U_lowres'+pref, 'U_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
blob_V = model.BilinearInterpolation('V_lowres'+pref, 'V_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
###
return blob_U,blob_V,blob_Index,blob_Ann_Index


def add_body_uv_losses(model, pref=''):

## Reshape for GT blobs.
model.net.Reshape( ['body_uv_X_points'], ['X_points_reshaped'+pref, 'X_points_shape'+pref], shape=( -1 ,1 ) )
model.net.Reshape( ['body_uv_Y_points'], ['Y_points_reshaped'+pref, 'Y_points_shape'+pref], shape=( -1 ,1 ) )
model.net.Reshape( ['body_uv_I_points'], ['I_points_reshaped'+pref, 'I_points_shape'+pref], shape=( -1 ,1 ) )
model.net.Reshape( ['body_uv_Ind_points'], ['Ind_points_reshaped'+pref, 'Ind_points_shape'+pref], shape=( -1 ,1 ) )
## Concat Ind,x,y to get Coordinates blob.
model.net.Concat( ['Ind_points_reshaped'+pref,'X_points_reshaped'+pref, \
'Y_points_reshaped'+pref],['Coordinates'+pref,'Coordinate_Shapes'+pref ], axis = 1 )
##
### Now reshape UV blobs, such that they are 1x1x(196*NumSamples)xNUM_PATCHES
## U blob to
##
model.net.Reshape(['body_uv_U_points'], \
['U_points_reshaped'+pref, 'U_points_old_shape'+pref],\
shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
model.net.Transpose(['U_points_reshaped'+pref] ,['U_points_reshaped_transpose'+pref],axes=(0,2,1) )
model.net.Reshape(['U_points_reshaped_transpose'+pref], \
['U_points'+pref, 'U_points_old_shape2'+pref], \
shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))
## V blob
##
model.net.Reshape(['body_uv_V_points'], \
['V_points_reshaped'+pref, 'V_points_old_shape'+pref],\
shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
model.net.Transpose(['V_points_reshaped'+pref] ,['V_points_reshaped_transpose'+pref],axes=(0,2,1) )
model.net.Reshape(['V_points_reshaped_transpose'+pref], \
['V_points'+pref, 'V_points_old_shape2'+pref], \
shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))
###
## UV weights blob
##
model.net.Reshape(['body_uv_point_weights'], \
['Uv_point_weights_reshaped'+pref, 'Uv_point_weights_old_shape'+pref],\
shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
model.net.Transpose(['Uv_point_weights_reshaped'+pref] ,['Uv_point_weights_reshaped_transpose'+pref],axes=(0,2,1) )
model.net.Reshape(['Uv_point_weights_reshaped_transpose'+pref], \
['Uv_point_weights'+pref, 'Uv_point_weights_old_shape2'+pref], \
shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))

#####################
### Pool IUV for points via bilinear interpolation.
model.PoolPointsInterp(['U_estimated','Coordinates'+pref], ['interp_U'+pref])
model.PoolPointsInterp(['V_estimated','Coordinates'+pref], ['interp_V'+pref])
model.PoolPointsInterp(['Index_UV'+pref,'Coordinates'+pref], ['interp_Index_UV'+pref])

## Reshape interpolated UV coordinates to apply the loss.

model.net.Reshape(['interp_U'+pref], \
['interp_U_reshaped'+pref, 'interp_U_shape'+pref],\
shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1))

model.net.Reshape(['interp_V'+pref], \
['interp_V_reshaped'+pref, 'interp_V_shape'+pref],\
shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1))
###

### Do the actual labels here !!!!
model.net.Reshape( ['body_uv_ann_labels'], \
['body_uv_ann_labels_reshaped' +pref, 'body_uv_ann_labels_old_shape'+pref], \
shape=(-1, cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE))

model.net.Reshape( ['body_uv_ann_weights'], \
['body_uv_ann_weights_reshaped' +pref, 'body_uv_ann_weights_old_shape'+pref], \
shape=( -1 , cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE))
###
model.net.Cast( ['I_points_reshaped'+pref], ['I_points_reshaped_int'+pref], to=core.DataType.INT32)
### Now add the actual losses
## The mask segmentation loss (dense)
probs_seg_AnnIndex, loss_seg_AnnIndex = model.net.SpatialSoftmaxWithLoss( \
['AnnIndex'+pref, 'body_uv_ann_labels_reshaped'+pref,'body_uv_ann_weights_reshaped'+pref],\
['probs_seg_AnnIndex'+pref,'loss_seg_AnnIndex'+pref], \
scale=cfg.BODY_UV_RCNN.INDEX_WEIGHTS / cfg.NUM_GPUS)
## Point Patch Index Loss.
probs_IndexUVPoints, loss_IndexUVPoints = model.net.SoftmaxWithLoss(\
['interp_Index_UV'+pref,'I_points_reshaped_int'+pref],\
['probs_IndexUVPoints'+pref,'loss_IndexUVPoints'+pref], \
scale=cfg.BODY_UV_RCNN.PART_WEIGHTS / cfg.NUM_GPUS, spatial=0)
## U and V point losses.
loss_Upoints = model.net.SmoothL1Loss( \
['interp_U_reshaped'+pref, 'U_points'+pref, \
'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \
'loss_Upoints'+pref, \
scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS)
bias_init=const_fill(0.0)
)
# Increase heatmap output size via bilinear upsampling
blob_outputs = []
for name in ['AnnIndex', 'Index_UV', 'U', 'V']:
blob_outputs.append(
model.BilinearInterpolation(
name + '_lowres',
name + '_estimated' if name in ['U', 'V'] else name,
cfg.BODY_UV_RCNN.NUM_PATCHES + 1,
cfg.BODY_UV_RCNN.NUM_PATCHES + 1,
cfg.BODY_UV_RCNN.UP_SCALE
)
)

return blob_outputs


def add_body_uv_losses(model):
"""Add DensePose body UV specific losses."""
# Pool estimated IUV points via bilinear interpolation.
for name in ['U', 'V', 'Index_UV']:
model.PoolPointsInterp(
[
name + '_estimated' if name in ['U', 'V'] else name,
'body_uv_coords_xy'
],
['interp_' + name]
)

loss_Vpoints = model.net.SmoothL1Loss( \
['interp_V_reshaped'+pref, 'V_points'+pref, \
'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \
'loss_Vpoints'+pref, scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS)
## Add the losses.
loss_gradients = blob_utils.get_loss_gradients(model, \
[ loss_Upoints, loss_Vpoints, loss_seg_AnnIndex, loss_IndexUVPoints])
model.losses = list(set(model.losses + \
['loss_Upoints'+pref , 'loss_Vpoints'+pref , \
'loss_seg_AnnIndex'+pref ,'loss_IndexUVPoints'+pref]))
# Compute spatial softmax normalized probabilities, after which
# cross-entropy loss is computed for semantic parts classification.
probs_AnnIndex, loss_AnnIndex = model.net.SpatialSoftmaxWithLoss(
[
'AnnIndex',
'body_uv_parts', 'body_uv_parts_weights'
],
['probs_AnnIndex', 'loss_AnnIndex'],
scale=cfg.BODY_UV_RCNN.INDEX_WEIGHTS / cfg.NUM_GPUS
)
# Softmax loss for surface patch classification.
probs_I_points, loss_I_points = model.net.SoftmaxWithLoss(
['interp_Index_UV', 'body_uv_I_points'],
['probs_I_points', 'loss_I_points'],
scale=cfg.BODY_UV_RCNN.PART_WEIGHTS / cfg.NUM_GPUS,
spatial=0
)
## Smooth L1 loss for each patch-specific UV coordinates regression.
# Reshape U,V blobs of both interpolated and ground-truth to compute
# summarized (instead of averaged) SmoothL1Loss.
loss_UV = list()
model.net.Reshape(
['body_uv_point_weights'],
['UV_point_weights', 'body_uv_point_weights_shape'],
shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
)
for name in ['U', 'V']:
# Reshape U/V coordinates of both interpolated points and ground-truth
# points from (#points, #patches) to (1, #points, #patches).
model.net.Reshape(
['body_uv_' + name + '_points'],
[name + '_points', 'body_uv_' + name + '_points_shape'],
shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
)
model.net.Reshape(
['interp_' + name],
['interp_' + name + '_reshaped', 'interp_' + name + 'shape'],
shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
)
# Compute summarized SmoothL1Loss of all points.
loss_UV.append(
model.net.SmoothL1Loss(
[
'interp_' + name + '_reshaped', name + '_points',
'UV_point_weights', 'UV_point_weights'
],
'loss_' + name + '_points',
scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS
)
)
# Add all losses to compute gradients
loss_gradients = blob_utils.get_loss_gradients(
model, [loss_AnnIndex, loss_I_points] + loss_UV
)
# Update model training losses
model.AddLosses(
['loss_' + name for name in ['AnnIndex', 'I_points', 'U_points', 'V_points']]
)

return loss_gradients

Expand All @@ -155,17 +151,17 @@ def add_body_uv_losses(model, pref=''):
# Body UV heads
# ---------------------------------------------------------------------------- #

def add_ResNet_roi_conv5_head_for_bodyUV(
model, blob_in, dim_in, spatial_scale
):
def add_ResNet_roi_conv5_head_for_bodyUV(model, blob_in, dim_in, spatial_scale):
"""Add a ResNet "conv5" / "stage5" head for body UV prediction."""
model.RoIFeatureTransform(
blob_in, '_[body_uv]_pool5',
blob_in,
'_[body_uv]_pool5',
blob_rois='body_uv_rois',
method=cfg.BODY_UV_RCNN.ROI_XFORM_METHOD,
resolution=cfg.BODY_UV_RCNN.ROI_XFORM_RESOLUTION,
sampling_ratio=cfg.BODY_UV_RCNN.ROI_XFORM_SAMPLING_RATIO,
spatial_scale=spatial_scale)
spatial_scale=spatial_scale
)
# Using the prefix '_[body_uv]_' to 'res5' enables initializing the head's
# parameters using pretrained 'res5' parameters if given (see
# utils.net.initialize_from_weights_file)
Expand All @@ -184,7 +180,7 @@ def add_ResNet_roi_conv5_head_for_bodyUV(


def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale):
"""v1convX design: X * (conv)."""
"""Add a DensePose body UV head. v1convX design: X * (conv)."""
hidden_dim = cfg.BODY_UV_RCNN.CONV_HEAD_DIM
kernel_size = cfg.BODY_UV_RCNN.CONV_HEAD_KERNEL
pad_size = kernel_size // 2
Expand All @@ -208,7 +204,7 @@ def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale):
stride=1,
pad=pad_size,
weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.01}),
bias_init=('ConstantFill', {'value': 0.})
bias_init=const_fill(0.0)
)
current = model.Relu(current, current)
dim_in = hidden_dim
Expand Down
Loading

0 comments on commit 723ab20

Please sign in to comment.