losses.py
import torch
import torch.nn.functional as F

from image_warping import ImageWarping
from utils import intrinsics

# This code was adapted from: https://github.com/ClementPinard/SfmLearner-Pytorch


class ViewSynthesisLoss:
    def __init__(self, device, rotation_mode='euler', padding_mode='zeros', scale=4):
        """Custom loss class implementing differentiable photometric reconstruction
        loss (view synthesis) and smoothness loss (depth).
        """
        self.device = device
        self.warper = ImageWarping(rotation_mode, padding_mode)
        # downscale the intrinsic matrix parameters to match the dataloader transforms
        intrinsic = torch.from_numpy(intrinsics).to(device)
        self.intrinsic = torch.cat((intrinsic[:2, :] / scale, intrinsic[2:, :]), dim=0)
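
    # Worked example of the scaling above (values assumed for illustration, not
    # taken from this repository): with fx = fy = 500, cx = 320, cy = 240 and
    # scale = 4, the first two rows of K are divided by 4, giving fx' = fy' = 125,
    # cx' = 80, cy' = 60, while the bottom row [0, 0, 1] is kept as-is.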
    def photometric_reconstruction_loss(self, tgt_img, depth, ref_imgs, poses):
        """Compute the photometric reconstruction loss between reference images and a
        target image via a view-synthesis framework. The loss is computed at multiple
        scale spaces.

        Args:
            tgt_img: target frame image -- [B, 1, H, W]
            depth: predicted depths -- [S, B, 1, H, W] - S is the number of scale spaces
            ref_imgs: reference frame images -- [Seq, B, 1, H, W] - Seq is the sequence length
            poses: predicted poses -- [B, Seq, 6]

        Returns:
            reconstruction_loss: novel view synthesis loss summed over all scales
            warped_imgs: warped reference frame images for visualization
            diff_maps: masked photometric difference maps for visualization
        """
        def one_scale(scale_depth):
            """Compute the photometric loss at a single scale space."""
            # interpolate target and reference images to match the scale space
            b, _, h, w = scale_depth.size()
            tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
            ref_imgs_scaled = [F.interpolate(ref_img, (h, w), mode='area') for ref_img in ref_imgs]
            # downscale the intrinsic matrix to match the scale space
            downscale = tgt_img.size(2) / h
            intrinsic = self.intrinsic.clone().to(self.device)
            intrinsic_scaled = torch.cat((intrinsic[:2, :] / downscale, intrinsic[2:, :]), dim=0)

            warped_imgs = []
            diff_maps = []
            reconstruction_loss = 0
            # compute the loss over each reference-target pair in the sequence
            for i, ref_img in enumerate(ref_imgs_scaled):
                pose = poses[:, i]  # [B, Seq, 6] -> [B, 6]
                # batch inverse warp
                ref_warped_img, valid_points = self.warper.inverse_warp(
                    scale_depth[:, 0], ref_img, pose, intrinsic_scaled
                )
                assert ref_warped_img.size() == tgt_img_scaled.size()
                # compute the absolute pixel difference on valid points only
                masked_pixel_error = (tgt_img_scaled - ref_warped_img) * valid_points.unsqueeze(1).float()
                reconstruction_loss += masked_pixel_error.abs().mean()
                # guard against NaNs (x == x is False only for NaN)
                assert (reconstruction_loss == reconstruction_loss).item() == 1
                # store the first sample of each batch for visualization
                warped_imgs.append(ref_warped_img[0])
                diff_maps.append(masked_pixel_error[0])
            return reconstruction_loss, warped_imgs, diff_maps
        # accumulate the loss over each scale space
        warped_results, diff_results = [], []
        total_loss = 0
        for d in depth:
            loss, warped, diff = one_scale(d)
            total_loss += loss
            warped_results.append(warped)
            diff_results.append(diff)
        return total_loss, warped_results, diff_results
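
    # Example call shape trace (hedged; S = 4 scales and Seq = 2 reference frames
    # are assumed, not prescribed by this file):
    #   loss, warped, diff = criterion.photometric_reconstruction_loss(
    #       tgt,                # [B, 1, H, W]
    #       [d0, d1, d2, d3],   # each [B, 1, H/2**s, W/2**s]
    #       [ref0, ref1],       # each [B, 1, H, W]
    #       poses)              # [B, 2, 6]
    #   loss is a scalar summed over the 4 scales; warped/diff hold the first
    #   batch sample at every scale for visualization.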
    def inverse_warp_loss(self, tgt_img, depth, ref_img, poses):
        """Compute the photometric reconstruction loss between a reference image and a
        target image via a view-synthesis framework, using ground truth poses. The loss
        is computed at multiple scale spaces.

        Args:
            tgt_img: target frame image -- [B, 1, H, W]
            depth: predicted depths -- [S, B, 1, H, W] - S is the number of scale spaces
            ref_img: reference frame image -- [B, 1, H, W]
            poses: ground truth homogeneous poses -- [B, 4, 4]

        Returns:
            reconstruction_loss: novel view synthesis loss summed over all scales
            warped_imgs: warped reference frame images for visualization
            diff_maps: masked photometric difference maps for visualization
        """
        def one_scale(scale_depth):
            """Compute the photometric loss at a single scale space."""
            # interpolate target and reference images to match the scale space
            b, _, h, w = scale_depth.size()
            tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
            ref_img_scaled = F.interpolate(ref_img, (h, w), mode='area')
            # downscale the intrinsic matrix to match the scale space
            downscale = tgt_img.size(2) / h
            intrinsic = self.intrinsic.clone().to(self.device)
            intrinsic_scaled = torch.cat((intrinsic[:2, :] / downscale, intrinsic[2:, :]), dim=0)

            # batch inverse warp with the ground truth pose
            ref_warped_img, valid_points = self.warper.inverse_warp_gt_pose(
                scale_depth[:, 0], ref_img_scaled, poses, intrinsic_scaled
            )
            assert ref_warped_img.size() == tgt_img_scaled.size()
            # compute the absolute pixel difference on valid points only
            masked_pixel_error = (tgt_img_scaled - ref_warped_img) * valid_points.unsqueeze(1).float()
            reconstruction_loss = masked_pixel_error.abs().mean()
            # guard against NaNs (x == x is False only for NaN)
            assert (reconstruction_loss == reconstruction_loss).item() == 1
            return reconstruction_loss, ref_warped_img[0], masked_pixel_error[0]
        # accumulate the loss over each scale space
        warped_results, diff_results = [], []
        total_loss = 0
        for d in depth:
            loss, warped, diff = one_scale(d)
            total_loss += loss
            warped_results.append(warped)
            diff_results.append(diff)
        return total_loss, warped_results, diff_results
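
    # Hedged usage note: unlike photometric_reconstruction_loss, this variant takes a
    # single reference image and a [B, 4, 4] homogeneous ground truth pose, so it
    # isolates the depth network from pose prediction errors, e.g.:
    #   loss, warped, diff = criterion.inverse_warp_loss(tgt, depths, ref, gt_pose)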
    def smoothness_loss(self, pred_depth):
        """Compute the smoothness loss over the predicted depth maps.

        Args:
            pred_depth: predicted depth maps -- [S, B, 1, H, W]
        """
        def gradient(pred):
            """Compute pixel gradients in the x and y directions."""
            D_dy = pred[:, :, 1:] - pred[:, :, :-1]
            D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
            return D_dx, D_dy

        loss = 0
        weight = 1.
        # penalize the mean absolute second-order gradients of the depth maps at each
        # scale space, reducing the weight by a factor of 2.3 at each successive scale
        for scaled_map in pred_depth:
            dx, dy = gradient(scaled_map)
            dx2, dxdy = gradient(dx)
            dydx, dy2 = gradient(dy)
            loss += (dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()) * weight
            weight /= 2.3
        return loss
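

# --- Usage sketch (illustrative, not part of the adapted SfmLearner code) ---
# A minimal, hedged example of driving the losses with dummy tensors. It assumes the
# ImageWarping and intrinsics imports above resolve (both live elsewhere in this
# repository) and uses made-up shapes: S = 4 depth scales, Seq = 2 reference frames,
# B = 2 grayscale 128x128 frames.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = ViewSynthesisLoss(device, scale=4)

    B, S, Seq, H, W = 2, 4, 2, 128, 128
    tgt = torch.rand(B, 1, H, W, device=device)
    refs = [torch.rand(B, 1, H, W, device=device) for _ in range(Seq)]
    # one predicted depth map per scale, halving the resolution each time
    depths = [torch.rand(B, 1, H // 2 ** s, W // 2 ** s, device=device) + 0.1 for s in range(S)]
    poses = torch.zeros(B, Seq, 6, device=device)  # identity 6-DoF poses

    photo_loss, warped, diff = criterion.photometric_reconstruction_loss(tgt, depths, refs, poses)
    smooth_loss = criterion.smoothness_loss(depths)
    # a 0.1 smoothness weighting is a common choice, not prescribed by this file
    total = photo_loss + 0.1 * smooth_loss
    print(total.item())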