From dfbdc27202075b500577c64d3f0d6c8438b86cfd Mon Sep 17 00:00:00 2001
From: VitorGuizilini-TRI
<58576956+VitorGuizilini-TRI@users.noreply.github.com>
Date: Wed, 10 Jun 2020 09:50:03 -0700
Subject: [PATCH] DGP Multi-Cam + Velocity Loss + FP16 Inference + PackNetSlim
(#30)
* Support for multi-camera loading on the DGP dataset
* Support for fp16 at inference time
* Velocity loss (see VelSupModel)
* PackNetSlim01 (faster version of PackNet)
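The sketches below illustrate the new features; they are illustrative snippets, not part of the diff.

Multi-camera loading: camera options are now double lists, one inner list of camera names per dataset entry. A minimal config sketch, assuming the module-level `cfg` from configs/default_config.py, the usual cfg.model.depth_net.name selector, and illustrative DDAD camera names:

    from configs.default_config import cfg

    # One inner list of camera names per dataset entry (double list)
    cfg.datasets.train.cameras = [['camera_01', 'camera_05']]
    cfg.datasets.validation.cameras = [['camera_01']]
    # Hypothetical override: select the slimmer depth network added by this patch
    cfg.model.depth_net.name = 'PackNetSlim01'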
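FP16 inference casts the trained networks and their inputs to half precision; no retraining is involved. A self-contained sketch of the idea (stand-in network and input, not the actual code in scripts/infer.py):

    import torch
    import torch.nn as nn

    # Stand-in depth network; the patch applies the same casts to the real one
    net = nn.Conv2d(3, 1, kernel_size=3, padding=1).cuda().eval().half()
    image = torch.rand(1, 3, 384, 640).cuda().half()  # fp16 input

    with torch.no_grad():
        inv_depth = net(image)     # forward pass runs in fp16
    inv_depth = inv_depth.float()  # back to fp32 for saving / evaluation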
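The velocity loss supervises the translation scale of the pose network with ground-truth pose magnitudes, weighted by the new cfg.model.loss.velocity_loss_weight (0.1 by default). A minimal sketch of the idea; the actual implementation lives in packnet_sfm/losses/velocity_loss.py and is wired up by VelSupModel:

    import torch

    def velocity_loss(pred_poses, gt_poses, weight=0.1):
        """L1 distance between predicted and GT translation magnitudes.

        pred_poses / gt_poses: lists of [B,4,4] context pose matrices
        (hypothetical signature for illustration).
        """
        loss = sum((pred[:, :3, -1].norm(dim=-1) -
                    gt[:, :3, -1].norm(dim=-1)).abs().mean()
                   for pred, gt in zip(pred_poses, gt_poses)) / len(gt_poses)
        return weight * loss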
---
configs/default_config.py | 18 +-
configs/eval_ddad.yaml | 2 +-
configs/overfit_ddad.yaml | 6 +-
configs/train_ddad.yaml | 6 +-
.../packnet_sfm/datasets/dgp_dataset.html | 128 +-
.../packnet_sfm/datasets/kitti_dataset.html | 13 +-
.../packnet_sfm/datasets/transforms.html | 9 +-
.../losses/multiview_photometric_loss.html | 2 +-
.../packnet_sfm/models/SelfSupModel.html | 24 +-
.../packnet_sfm/models/SemiSupModel.html | 22 +-
.../_modules/packnet_sfm/models/SfmModel.html | 47 +-
.../packnet_sfm/models/model_utils.html | 30 +
.../packnet_sfm/models/model_wrapper.html | 54 +-
.../packnet_sfm/networks/depth/PackNet01.html | 8 +-
.../packnet_sfm/trainers/base_trainer.html | 11 +-
docs/_modules/packnet_sfm/utils/config.html | 11 +-
docs/_modules/packnet_sfm/utils/logging.html | 11 +-
docs/_modules/packnet_sfm/utils/save.html | 3 -
docs/_modules/scripts/eval.html | 9 +-
docs/_modules/scripts/infer.html | 111 +-
docs/_static/basic.css | 121 +-
.../{jquery-3.4.1.js => jquery-3.5.1.js} | 1238 ++++++++++-------
docs/_static/jquery.js | 4 +-
docs/datasets/datasets.DGPDataset.html | 42 +
docs/datasets/datasets.KITTIDataset.html | 3 +-
docs/genindex.html | 59 +-
.../losses.multiview_photometric_loss.html | 2 +-
docs/models/models.SelfSupModel.html | 28 +-
docs/models/models.SemiSupModel.html | 20 -
docs/models/models.SfmModel.html | 62 +-
docs/models/models.Utilities.html | 17 +
docs/models/models.Wrapper.html | 12 +-
docs/objects.inv | Bin 4180 -> 4169 bytes
docs/scripts/scripts.eval.html | 2 +-
docs/scripts/scripts.infer.html | 40 +-
docs/searchindex.js | 2 +-
docs/trainers/trainers.BaseTrainer.html | 2 +-
packnet_sfm/datasets/dgp_dataset.py | 128 +-
packnet_sfm/datasets/kitti_dataset.py | 13 +-
packnet_sfm/datasets/transforms.py | 9 +-
.../losses/multiview_photometric_loss.py | 2 +-
packnet_sfm/losses/velocity_loss.py | 42 +
packnet_sfm/models/SelfSupModel.py | 24 +-
packnet_sfm/models/SemiSupModel.py | 22 +-
packnet_sfm/models/SfmModel.py | 47 +-
packnet_sfm/models/VelSupModel.py | 52 +
packnet_sfm/models/model_utils.py | 30 +
packnet_sfm/models/model_wrapper.py | 54 +-
packnet_sfm/networks/depth/PackNet01.py | 8 +-
packnet_sfm/networks/depth/PackNetSlim01.py | 183 +++
packnet_sfm/trainers/base_trainer.py | 11 +-
packnet_sfm/trainers/horovod_trainer.py | 6 +-
packnet_sfm/utils/config.py | 11 +-
packnet_sfm/utils/logging.py | 11 +-
packnet_sfm/utils/save.py | 3 -
scripts/eval.py | 9 +-
scripts/infer.py | 111 +-
57 files changed, 1972 insertions(+), 983 deletions(-)
rename docs/_static/{jquery-3.4.1.js => jquery-3.5.1.js} (91%)
create mode 100644 packnet_sfm/losses/velocity_loss.py
create mode 100644 packnet_sfm/models/VelSupModel.py
create mode 100644 packnet_sfm/networks/depth/PackNetSlim01.py
diff --git a/configs/default_config.py b/configs/default_config.py
index 2e2b23e2..fa2c09d3 100644
--- a/configs/default_config.py
+++ b/configs/default_config.py
@@ -80,11 +80,11 @@
########################################################################################################################
cfg.model.loss = CN()
#
-cfg.model.loss.num_scales = 4 # Number of inverse depth scales to use
-cfg.model.loss.progressive_scaling = 0.0 # Training percentage to decay number of scales
-cfg.model.loss.flip_lr_prob = 0.5 # Probablity of horizontal flippping
-cfg.model.loss.rotation_mode = 'euler' # Rotation mode
-cfg.model.loss.upsample_depth_maps = True # Resize depth maps to highest resolution
+cfg.model.loss.num_scales = 4 # Number of inverse depth scales to use
+cfg.model.loss.progressive_scaling = 0.0 # Training percentage to decay number of scales
+cfg.model.loss.flip_lr_prob = 0.5              # Probability of horizontal flipping
+cfg.model.loss.rotation_mode = 'euler' # Rotation mode
+cfg.model.loss.upsample_depth_maps = True # Resize depth maps to highest resolution
#
cfg.model.loss.ssim_loss_weight = 0.85 # SSIM loss weight
cfg.model.loss.occ_reg_weight = 0.1 # Occlusion regularizer loss weight
@@ -97,6 +97,8 @@
cfg.model.loss.padding_mode = 'zeros' # Photometric loss padding mode
cfg.model.loss.automask_loss = True # Automasking to remove static pixels
#
+cfg.model.loss.velocity_loss_weight = 0.1 # Velocity supervision loss weight
+#
cfg.model.loss.supervised_method = 'sparse-l1' # Method for depth supervision
cfg.model.loss.supervised_num_scales = 4 # Number of scales for supervised learning
cfg.model.loss.supervised_loss_weight = 0.9 # Supervised loss weight
@@ -138,7 +140,7 @@
cfg.datasets.train.path = [] # Training data path
cfg.datasets.train.split = [] # Training split
cfg.datasets.train.depth_type = [''] # Training depth type
-cfg.datasets.train.cameras = [] # Training cameras
+cfg.datasets.train.cameras = [[]] # Training cameras (double list, one for each dataset)
cfg.datasets.train.repeat = [1] # Number of times training dataset is repeated per epoch
cfg.datasets.train.num_logs = 5 # Number of training images to log
########################################################################################################################
@@ -153,7 +155,7 @@
cfg.datasets.validation.path = [] # Validation data path
cfg.datasets.validation.split = [] # Validation split
cfg.datasets.validation.depth_type = [''] # Validation depth type
-cfg.datasets.validation.cameras = [] # Validation cameras
+cfg.datasets.validation.cameras = [[]] # Validation cameras (double list, one for each dataset)
cfg.datasets.validation.num_logs = 5 # Number of validation images to log
########################################################################################################################
### DATASETS.TEST
@@ -167,7 +169,7 @@
cfg.datasets.test.path = [] # Test data path
cfg.datasets.test.split = [] # Test split
cfg.datasets.test.depth_type = [''] # Test depth type
-cfg.datasets.test.cameras = [] # Test cameras
+cfg.datasets.test.cameras = [[]] # Test cameras (double list, one for each dataset)
cfg.datasets.test.num_logs = 5 # Number of test images to log
########################################################################################################################
### THESE SHOULD NOT BE CHANGED
diff --git a/configs/eval_ddad.yaml b/configs/eval_ddad.yaml
index ca490fee..bc1e153b 100644
--- a/configs/eval_ddad.yaml
+++ b/configs/eval_ddad.yaml
@@ -18,7 +18,7 @@ datasets:
path: ['/data/datasets/DDAD/ddad.json']
split: ['val']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
save:
folder: '/data/save'
viz: True
diff --git a/configs/overfit_ddad.yaml b/configs/overfit_ddad.yaml
index 7cfdd482..72e4447a 100644
--- a/configs/overfit_ddad.yaml
+++ b/configs/overfit_ddad.yaml
@@ -31,17 +31,17 @@ datasets:
path: ['/data/datasets/DDAD_tiny/ddad_tiny.json']
split: ['train']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
repeat: [500]
validation:
dataset: ['DGP']
path: ['/data/datasets/DDAD_tiny/ddad_tiny.json']
split: ['train']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
test:
dataset: ['DGP']
path: ['/data/datasets/DDAD_tiny/ddad_tiny.json']
split: ['train']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
diff --git a/configs/train_ddad.yaml b/configs/train_ddad.yaml
index ea4c7c31..a047b48f 100644
--- a/configs/train_ddad.yaml
+++ b/configs/train_ddad.yaml
@@ -30,7 +30,7 @@ datasets:
path: ['/data/datasets/DDAD/ddad.json']
split: ['train']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
repeat: [5]
validation:
num_workers: 8
@@ -38,11 +38,11 @@ datasets:
path: ['/data/datasets/DDAD/ddad.json']
split: ['val']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
test:
num_workers: 8
dataset: ['DGP']
path: ['/data/datasets/DDAD/ddad.json']
split: ['val']
depth_type: ['lidar']
- cameras: ['camera_01']
+ cameras: [['camera_01']]
diff --git a/docs/_modules/packnet_sfm/datasets/dgp_dataset.html b/docs/_modules/packnet_sfm/datasets/dgp_dataset.html
index 14efb8da..7459d2b2 100644
--- a/docs/_modules/packnet_sfm/datasets/dgp_dataset.html
+++ b/docs/_modules/packnet_sfm/datasets/dgp_dataset.html
@@ -165,10 +165,16 @@
# Copyright 2020 Toyota Research Institute. All rights reserved.
+import os
import torch
-from packnet_sfm.utils.misc import make_list
-from packnet_sfm.utils.types import is_tensor
+import numpy as np
+
from dgp.datasets.synchronized_dataset import SynchronizedSceneDataset
+from dgp.utils.camera import Camera, generate_depth_map
+from dgp.utils.geometry import Pose
+
+from packnet_sfm.utils.misc import make_list
+from packnet_sfm.utils.types import is_tensor, is_numpy, is_list
########################################################################################################################
#### FUNCTIONS
@@ -189,7 +195,24 @@ Source code for packnet_sfm.datasets.dgp_dataset
        else:
            # Stack torch tensors
            if is_tensor(sample[0][key]):
-               stacked_sample[key] = torch.cat([s[key].unsqueeze(0) for s in sample], 0)
+               stacked_sample[key] = torch.stack([s[key] for s in sample], 0)
+           # Stack numpy arrays
+           elif is_numpy(sample[0][key]):
+               stacked_sample[key] = np.stack([s[key] for s in sample], 0)
+           # Stack list
+           elif is_list(sample[0][key]):
+               stacked_sample[key] = []
+               # Stack list of torch tensors
+               if is_tensor(sample[0][key][0]):
+                   for i in range(len(sample[0][key])):
+                       stacked_sample[key].append(
+                           torch.stack([s[key][i] for s in sample], 0))
+               # Stack list of numpy arrays
+               if is_numpy(sample[0][key][0]):
+                   for i in range(len(sample[0][key])):
+                       stacked_sample[key].append(
+                           np.stack([s[key][i] for s in sample], 0))
+

    # Return stacked sample
    return stacked_sample
@@ -231,6 +254,7 @@ Source code for packnet_sfm.datasets.dgp_dataset
forward_context=0,
data_transform=None,
):
+
self.path = path
self.split = split
self.dataset_idx = 0
@@ -241,6 +265,7 @@ Source code for packnet_sfm.datasets.dgp_dataset
        self.num_cameras = len(cameras)
        self.data_transform = data_transform
+
        self.depth_type = depth_type
        self.with_depth = depth_type is not None
        self.with_pose = with_pose
        self.with_semantic = with_semantic
@@ -250,11 +275,57 @@ Source code for packnet_sfm.datasets.dgp_dataset
            datum_names=cameras,
            backward_context=back_context,
            forward_context=forward_context,
-           generate_depth_from_datum=depth_type,
            requested_annotations=None,
            only_annotated_datums=False,
        )
+
+    def generate_depth_map(self, sample_idx, datum_idx, filename):
+        """
+        Generates the depth map for a camera by projecting LiDAR information.
+        It also caches the depth map following DGP folder structure, so it's not recalculated
+
+        Parameters
+        ----------
+        sample_idx : int
+            sample index
+        datum_idx : int
+            Datum index
+        filename :
+            Filename used for loading / saving
+
+        Returns
+        -------
+        depth : np.array [H, W]
+            Depth map for that datum in that sample
+        """
+        # Generate depth filename
+        filename = '{}/{}.npz'.format(
+            os.path.dirname(self.path), filename.format('depth/{}'.format(self.depth_type)))
+        # Load and return if exists
+        if os.path.exists(filename):
+            return np.load(filename)['depth']
+        # Otherwise, create, save and return
+        else:
+            # Get pointcloud
+            scene_idx, sample_idx_in_scene, _ = self.dataset.dataset_item_index[sample_idx]
+            pc_datum_idx_in_sample = self.dataset.get_datum_index_for_datum_name(
+                scene_idx, sample_idx_in_scene, self.depth_type)
+            pc_datum_data = self.dataset.get_point_cloud_from_datum(
+                scene_idx, sample_idx_in_scene, pc_datum_idx_in_sample)
+            # Create camera
+            camera_rgb = self.get_current('rgb', datum_idx)
+            camera_pose = self.get_current('pose', datum_idx)
+            camera_intrinsics = self.get_current('intrinsics', datum_idx)
+            camera = Camera(K=camera_intrinsics, p_cw=camera_pose.inverse())
+            # Generate depth map
+            world_points = pc_datum_data['pose'] * pc_datum_data['point_cloud']
+            depth = generate_depth_map(camera, world_points, camera_rgb.size[::-1])
+            # Save depth map
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            np.savez_compressed(filename, depth=depth)
+            # Return depth map
+            return depth
+
    def get_current(self, key, sensor_idx):
        """Return current timestep of a key from a sensor"""
        return self.sample_dgp[self.bwd][sensor_idx][key]
@@ -275,6 +346,29 @@ Source code for packnet_sfm.datasets.dgp_dataset
        """Get both backward and forward contexts"""
        return self.get_backward(key, sensor_idx) + self.get_forward(key, sensor_idx)
+
+    def get_filename(self, sample_idx, datum_idx):
+        """
+        Returns the filename for an index, following DGP structure
+
+        Parameters
+        ----------
+        sample_idx : int
+            Sample index
+        datum_idx : int
+            Datum index
+
+        Returns
+        -------
+        filename : str
+            Filename for the datum in that sample
+        """
+        scene_idx, sample_idx_in_scene, datum_indices = self.dataset.dataset_item_index[sample_idx]
+        scene_dir = self.dataset.get_scene_directory(scene_idx)
+        filename = self.dataset.get_datum(
+            scene_idx, sample_idx_in_scene, datum_indices[datum_idx]).datum.image.filename
+        return os.path.splitext(os.path.join(os.path.basename(scene_dir),
+            filename.replace('rgb', '{}')))[0]

    def __len__(self):
        """Length of dataset"""
        return len(self.dataset)
@@ -292,27 +386,45 @@ Source code for packnet_sfm.datasets.dgp_dataset
                'idx': idx,
                'dataset_idx': self.dataset_idx,
                'sensor_name': self.get_current('datum_name', i),
-               'filename': '%s_%010d' % (self.split, idx),
+               #
+               'filename': self.get_filename(idx, i),
+               'splitname': '%s_%010d' % (self.split, idx),
                #
                'rgb': self.get_current('rgb', i),
                'intrinsics': self.get_current('intrinsics', i),
            }
+           # If depth is returned
            if self.with_depth:
                data.update({
-                   'depth': self.get_current('depth', i),
+                   'depth': self.generate_depth_map(idx, i, data['filename'])
                })
+           # If pose is returned
            if self.with_pose:
                data.update({
-                   'extrinsics': [pose.matrix for pose in self.get_current('extrinsics', i)],
-                   'pose': [pose.matrix for pose in self.get_current('pose', i)],
+                   'extrinsics': self.get_current('extrinsics', i).matrix,
+                   'pose': self.get_current('pose', i).matrix,
                })
+           # If context is returned
            if self.has_context:
                data.update({
                    'rgb_context': self.get_context('rgb', i),
                })
+               # If context pose is returned
+               if self.with_pose:
+                   # Get original values to calculate relative motion
+                   orig_extrinsics = Pose.from_matrix(data['extrinsics'])
+                   orig_pose = Pose.from_matrix(data['pose'])
+                   data.update({
+                       'extrinsics_context':
+                           [(orig_extrinsics.inverse() * extrinsics).matrix
+                            for extrinsics in self.get_context('extrinsics', i)],
+                       'pose_context':
+                           [(orig_pose.inverse() * pose).matrix
+                            for pose in self.get_context('pose', i)],
+                   })

            sample.append(data)
diff --git a/docs/_modules/packnet_sfm/datasets/kitti_dataset.html b/docs/_modules/packnet_sfm/datasets/kitti_dataset.html
index f720cd8e..f9ecb61e 100644
--- a/docs/_modules/packnet_sfm/datasets/kitti_dataset.html
+++ b/docs/_modules/packnet_sfm/datasets/kitti_dataset.html
@@ -223,8 +223,6 @@ Source code for packnet_sfm.datasets.kitti_dataset
        Split file, with paths to the images to be used
    train : bool
        True if the dataset will be used for training
-   mode : str
-       Dataset mode (stereo or mono)
    data_transform : Function
        Transformations applied to the sample
    depth_type : str
@@ -238,7 +236,7 @@ Source code for packnet_sfm.datasets.kitti_dataset
    strides : tuple
        List of context strides
    """
-   def __init__(self, root_dir, file_list, train=True, mode='mono',
+   def __init__(self, root_dir, file_list, train=True,
                 data_transform=None, depth_type=None, with_pose=False,
                 back_context=0, forward_context=0, strides=(1,)):
        # Assertions
@@ -459,9 +457,12 @@ Source code for packnet_sfm.datasets.kitti_dataset
    def _get_oxts_file(image_file):
        """Gets the oxts file from an image file."""
        # find oxts pose file
-       oxts_file = image_file.replace(IMAGE_FOLDER['left'], OXTS_POSE_DATA)
-       oxts_file = oxts_file.replace('png', 'txt')
-       return oxts_file
+       for cam in ['left', 'right']:
+           # Check for both cameras, if found replace and return file name
+           if IMAGE_FOLDER[cam] in image_file:
+               return image_file.replace(IMAGE_FOLDER[cam], OXTS_POSE_DATA).replace('.png', '.txt')
+       # Something went wrong (invalid image file)
+       raise ValueError('Invalid KITTI path for pose supervision.')

    def _get_oxts_data(self, image_file):
        """Gets the oxts data from an image file."""
diff --git a/docs/_modules/packnet_sfm/datasets/transforms.html b/docs/_modules/packnet_sfm/datasets/transforms.html
index d15a6d3c..e0437de1 100644
--- a/docs/_modules/packnet_sfm/datasets/transforms.html
+++ b/docs/_modules/packnet_sfm/datasets/transforms.html
@@ -189,7 +189,8 @@ Source code for packnet_sfm.datasets.transforms
    sample : dict
        Augmented sample
    """
-   sample = resize_sample(sample, image_shape)
+   if len(image_shape) > 0:
+       sample = resize_sample(sample, image_shape)
    sample = duplicate_sample(sample)
    if len(jittering) > 0:
        sample = colorjitter_sample(sample, jittering)
@@ -212,7 +213,8 @@ Source code for packnet_sfm.datasets.transforms
sample : dict
Augmented sample
"""
- sample['rgb'] = resize_image(sample['rgb'], image_shape)
+ if len(image_shape) > 0:
+ sample['rgb'] = resize_image(sample['rgb'], image_shape)
sample = to_tensor_sample(sample)
return sample
@@ -232,7 +234,8 @@ Source code for packnet_sfm.datasets.transforms
sample : dict
Augmented sample
"""
- sample['rgb'] = resize_image(sample['rgb'], image_shape)
+ if len(image_shape) > 0:
+ sample['rgb'] = resize_image(sample['rgb'], image_shape)
sample = to_tensor_sample(sample)
return sample
diff --git a/docs/_modules/packnet_sfm/losses/multiview_photometric_loss.html b/docs/_modules/packnet_sfm/losses/multiview_photometric_loss.html
index c6389f26..df6f1620 100644
--- a/docs/_modules/packnet_sfm/losses/multiview_photometric_loss.html
+++ b/docs/_modules/packnet_sfm/losses/multiview_photometric_loss.html
@@ -178,7 +178,7 @@ Source code for packnet_sfm.losses.multiview_photometric_loss
def SSIM(x, y, C1=1e-4, C2=9e-4, kernel_size=3, stride=1):
    """
-   Structural SIMlilarity (SSIM) distance between two images.
+   Structural SIMilarity (SSIM) distance between two images.

    Parameters
    ----------
diff --git a/docs/_modules/packnet_sfm/models/SelfSupModel.html b/docs/_modules/packnet_sfm/models/SelfSupModel.html
index dfd9a143..80e6e81b 100644
--- a/docs/_modules/packnet_sfm/models/SelfSupModel.html
+++ b/docs/_modules/packnet_sfm/models/SelfSupModel.html
@@ -177,16 +177,12 @@ Source code for packnet_sfm.models.SelfSupModel
    Parameters
    ----------
-   depth_net : nn.Module
-       Depth network to be used
-   pose_net : nn.Module
-       Pose network to be used
    kwargs : dict
        Extra parameters
    """
-   def __init__(self, depth_net=None, pose_net=None, **kwargs):
+   def __init__(self, **kwargs):
        # Initializes SfmModel
-       super().__init__(depth_net, pose_net, **kwargs)
+       super().__init__(**kwargs)
        # Initializes the photometric loss
        self._photometric_loss = MultiViewPhotometricLoss(**kwargs)
@@ -198,22 +194,6 @@ Source code for packnet_sfm.models.SelfSupModel
            **self._photometric_loss.logs
        }

-   @property
-   def requires_depth_net(self):
-       return True
-
-   @property
-   def requires_pose_net(self):
-       return True
-
-   @property
-   def requires_gt_depth(self):
-       return False
-
-   @property
-   def requires_gt_pose(self):
-       return False
-
    def self_supervised_loss(self, image, ref_images, inv_depths, poses,
                             intrinsics, return_logs=False, progress=0.0):
        """
diff --git a/docs/_modules/packnet_sfm/models/SemiSupModel.html b/docs/_modules/packnet_sfm/models/SemiSupModel.html
index d9d3ab11..06324768 100644
--- a/docs/_modules/packnet_sfm/models/SemiSupModel.html
+++ b/docs/_modules/packnet_sfm/models/SemiSupModel.html
@@ -186,6 +186,7 @@ Source code for packnet_sfm.models.SemiSupModel
        Extra parameters
    """
    def __init__(self, supervised_loss_weight=0.9, **kwargs):
+       # Initializes SelfSupModel
        super().__init__(**kwargs)
        # If supervision weight is 0.0, use SelfSupModel directly
        assert 0. < supervised_loss_weight <= 1., "Model requires (0, 1] supervision"
@@ -193,6 +194,11 @@ Source code for packnet_sfm.models.SemiSupModel
        self.supervised_loss_weight = supervised_loss_weight
        self._supervised_loss = SupervisedLoss(**kwargs)
+       # Pose network is only required if there is self-supervision
+       self._network_requirements['pose_net'] = self.supervised_loss_weight < 1
+       # GT depth is only required if there is supervision
+       self._train_requirements['gt_depth'] = self.supervised_loss_weight > 0
+
    @property
    def logs(self):
        """Return logs."""
@@ -201,22 +207,6 @@ Source code for packnet_sfm.models.SemiSupModel
            **self._supervised_loss.logs
        }

-   @property
-   def requires_depth_net(self):
-       return True
-
-   @property
-   def requires_pose_net(self):
-       return self.supervised_loss_weight < 1.
-
-   @property
-   def requires_gt_depth(self):
-       return self.supervised_loss_weight > 0.
-
-   @property
-   def requires_gt_pose(self):
-       return False
-
    def supervised_loss(self, inv_depths, gt_inv_depths,
                        return_logs=False, progress=0.0):
        """
diff --git a/docs/_modules/packnet_sfm/models/SfmModel.html b/docs/_modules/packnet_sfm/models/SfmModel.html
index 208b0e48..c2d14059 100644
--- a/docs/_modules/packnet_sfm/models/SfmModel.html
+++ b/docs/_modules/packnet_sfm/models/SfmModel.html
@@ -187,7 +187,7 @@ Source code for packnet_sfm.models.SfmModel
flip_lr_prob : float
Probability of flipping when using the depth network
upsample_depth_maps : bool
- True if detph map scales are upsampled to highest resolution
+ True if depth map scales are upsampled to highest resolution
kwargs : dict
Extra parameters
"""
@@ -203,6 +203,15 @@ Source code for packnet_sfm.models.SfmModel
self._logs = {}
self._losses = {}
+ self._network_requirements = {
+ 'depth_net': True, # Depth network required
+ 'pose_net': True, # Pose network required
+ }
+ self._train_requirements = {
+ 'gt_depth': False, # No ground-truth depth required
+ 'gt_pose': False, # No ground-truth pose required
+ }
+
@property
def logs(self):
"""Return logs."""
@@ -218,25 +227,41 @@ Source code for packnet_sfm.models.SfmModel
self._losses[key] = val.detach()
@property
- def requires_depth_net(self):
- return True
+ def network_requirements(self):
+ """
+ Networks required to run the model
- @property
- def requires_pose_net(self):
- return True
+ Returns
+ -------
+ requirements : dict
+ depth_net : bool
+ Whether a depth network is required by the model
+ pose_net : bool
+ Whether a pose network is required by the model
+ """
+ return self._network_requirements
@property
- def requires_gt_depth(self):
- return False
+ def train_requirements(self):
+ """
+ Information required by the model at training stage
- @property
- def requires_gt_pose(self):
- return False
+ Returns
+ -------
+ requirements : dict
+ gt_depth : bool
+ Whether ground truth depth is required by the model at training time
+ gt_pose : bool
+ Whether ground truth pose is required by the model at training time
+ """
+ return self._train_requirements
    def add_depth_net(self, depth_net):
+        """Add a depth network to the model"""
        self.depth_net = depth_net

    def add_pose_net(self, pose_net):
+        """Add a pose network to the model"""
        self.pose_net = pose_net

    def compute_inv_depths(self, image):
diff --git a/docs/_modules/packnet_sfm/models/model_utils.html b/docs/_modules/packnet_sfm/models/model_utils.html
index d11f0fb3..e9bd6828 100644
--- a/docs/_modules/packnet_sfm/models/model_utils.html
+++ b/docs/_modules/packnet_sfm/models/model_utils.html
@@ -165,6 +165,7 @@ Source code for packnet_sfm.models.model_utils
# Copyright 2020 Toyota Research Institute. All rights reserved.

+from packnet_sfm.utils.types import is_tensor, is_list, is_numpy

def merge_outputs(*outputs):
    """
@@ -199,6 +200,35 @@ Source code for packnet_sfm.models.model_utils
            'Adding duplicated key {}'.format(key)
        merge[key] = val
    return merge
+
+
+def stack_batch(batch):
+    """
+    Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W)
+
+    Parameters
+    ----------
+    batch : dict
+        Batch
+
+    Returns
+    -------
+    batch : dict
+        Stacked batch
+    """
+    # If there is multi-camera information
+    if len(batch['rgb'].shape) == 5:
+        assert batch['rgb'].shape[0] == 1, 'Only batch size 1 is supported for multi-cameras'
+        # Loop over all keys
+        for key in batch.keys():
+            # If list, stack every item
+            if is_list(batch[key]):
+                if is_tensor(batch[key][0]) or is_numpy(batch[key][0]):
+                    batch[key] = [sample[0] for sample in batch[key]]
+            # Else, stack single item
+            else:
+                batch[key] = batch[key][0]
+    return batch
diff --git a/docs/_modules/packnet_sfm/models/model_wrapper.html b/docs/_modules/packnet_sfm/models/model_wrapper.html
index ef000b77..c1eae876 100644
--- a/docs/_modules/packnet_sfm/models/model_wrapper.html
+++ b/docs/_modules/packnet_sfm/models/model_wrapper.html
@@ -184,6 +184,7 @@ Source code for packnet_sfm.models.model_wrapper
from packnet_sfm.utils.reduce import all_reduce_metrics, reduce_dict, \
    create_dict, average_loss_and_metrics
from packnet_sfm.utils.save import save_depth
+from packnet_sfm.models.model_utils import stack_batch


class ModelWrapper(torch.nn.Module):
@@ -223,7 +224,10 @@ Source code for packnet_sfm.models.model_wrapper
        # Prepare datasets
        if load_datasets:
-           self.prepare_datasets()
+           # Requirements for validation (we only evaluate depth for now)
+           validation_requirements = {'gt_depth': True, 'gt_pose': False}
+           test_requirements = validation_requirements
+           self.prepare_datasets(validation_requirements, test_requirements)

        # Preparations done
        self.config.prepared = True
@@ -241,20 +245,24 @@ Source code for packnet_sfm.models.model_wrapper
        if 'epoch' in resume:
            self.current_epoch = resume['epoch']

-   def prepare_datasets(self):
+   def prepare_datasets(self, validation_requirements, test_requirements):
        """Prepare datasets for training, validation and test."""
-       # Prepare datasets
        print0(pcolor('### Preparing Datasets', 'green'))
        augmentation = self.config.datasets.augmentation
+       # Setup train dataset (requirements are given by the model itself)
        self.train_dataset = setup_dataset(
            self.config.datasets.train, 'train',
-           self.model.requires_gt_depth, **augmentation)
+           self.model.train_requirements, **augmentation)
+       # Setup validation dataset
        self.validation_dataset = setup_dataset(
-           self.config.datasets.validation, 'validation', **augmentation)
+           self.config.datasets.validation, 'validation',
+           validation_requirements, **augmentation)
+       # Setup test dataset
        self.test_dataset = setup_dataset(
-           self.config.datasets.test, 'test', **augmentation)
+           self.config.datasets.test, 'test',
+           test_requirements, **augmentation)

    @property
    def depth_net(self):
@@ -272,12 +280,17 @@ Source code for packnet_sfm.models.model_wrapper
        params = OrderedDict()
        for param in self.optimizer.param_groups:
            params['{}_learning_rate'.format(param['name'].lower())] = param['lr']
-       params['progress'] = self.current_epoch / self.config.arch.max_epochs
+       params['progress'] = self.progress
        return {
            **params,
            **self.model.logs,
        }
+
+   @property
+   def progress(self):
+       """Returns training progress (current epoch / max. number of epochs)"""
+       return self.current_epoch / self.config.arch.max_epochs