From e2283bcbd6ca5d2ef16d2cbdfe7d792979764631 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Sat, 6 Apr 2024 13:30:03 -0700
Subject: [PATCH] more intermediate work to switch to torchworld

---
 configs/simplebev3d_multi_pose.py  |  2 +-
 torchdrive/tasks/bev.py            | 44 ++++++++++++++++++++----------
 torchdrive/train_config.py         |  5 +++--
 torchworld/models/simplebev_3d.py  | 15 +++++-----
 torchworld/structures/grid.py      | 24 ++++++----------
 torchworld/structures/test_grid.py | 15 ++++++++--
 torchworld/transforms/simplebev.py |  5 ++---
 7 files changed, 63 insertions(+), 47 deletions(-)

diff --git a/configs/simplebev3d_multi_pose.py b/configs/simplebev3d_multi_pose.py
index 706aa88..b9d06b8 100644
--- a/configs/simplebev3d_multi_pose.py
+++ b/configs/simplebev3d_multi_pose.py
@@ -35,7 +35,7 @@
     autolabel_path=None,  # "/mnt/ext3/autolabel2",
     mask_path="n/a",  # only used for rice dataset
     num_workers=4,
-    batch_size=6,
+    batch_size=4,
     # tasks
     det=False,
     ae=False,
diff --git a/torchdrive/tasks/bev.py b/torchdrive/tasks/bev.py
index 6ba3395..4a2ed3a 100644
--- a/torchdrive/tasks/bev.py
+++ b/torchdrive/tasks/bev.py
@@ -9,7 +9,9 @@
 from torch import nn
 from torch.cuda import amp
 from torch.utils.tensorboard import SummaryWriter
+
 from torchworld.transforms.img import render_color
+from torchworld.structures.grid import Grid3d
 
 from torchdrive.amp import autocast
 from torchdrive.autograd import (
@@ -57,6 +59,8 @@ def __init__(
         dim: int,
         cam_dim: int,
         hr_dim: int,
+        scale: float,
+        grid_shape: Tuple[int, int, int],
         cam_features_mask_ratio: float = 0.0,
         num_encode_frames: int = 3,
         num_backprop_frames: int = 2,
@@ -85,6 +89,9 @@ def __init__(
         self.camera_encoders = nn.ModuleDict(
             {cam: compile_fn(cam_encoder()) for cam in cameras}
         )
+        self.grid_shape = grid_shape
+        self.hr_dim = hr_dim
+        self.scale = scale
 
         for task in tasks.values():
             task.set_camera_encoders(self.camera_encoders)
@@ -167,27 +174,28 @@ def forward(
             # run frames in parallel
             encoder = self.camera_encoders[cam]
             encoder.train()
 
-            inp = batch.color[cam][:, : self.num_encode_frames]
+            inp = [batch.grid_image(cam, i) for i in range(self.num_encode_frames)]
             # use gradient checkpointing to save memory
-            feats = torch.utils.checkpoint.checkpoint(
-                encoder, inp.flatten(0, 1)
-            ).unflatten(0, inp.shape[0:2])
-            num_frames = feats.size(1)
+            feats = [
+                torch.utils.checkpoint.checkpoint(
+                    encoder, frame_inp
+                ) for frame_inp in inp
+            ]
+            num_frames = self.num_encode_frames
 
             to_mask = (
-                torch.rand(feats.shape[-2:], device=feats.device)
+                torch.rand(feats[0].shape[-2:], device=feats[0].device)
                 < self.cam_features_mask_ratio
             )
-            feats[:, :, :, to_mask] = (
-                self.cam_mask_value.weight.unsqueeze(0)
-                .unsqueeze(-1)
-                .to(feats.dtype)
-            )
+            for feat in feats:
+                feat[:, :, to_mask] = (
+                    self.cam_mask_value.weight.unsqueeze(0)
+                    .unsqueeze(-1)
+                    .to(feat.dtype)
+                )
 
-            for i in range(num_frames):
-                feat = feats[:, i]
-                camera_feats[cam].append(feat)
+            camera_feats[cam] = feats
 
         for cam, cam_feats in camera_feats.items():
             if torch.is_anomaly_check_nan_enabled():
@@ -195,7 +203,13 @@ def forward(
                 assert not torch.isnan(feats).any().item(), cam
 
         with torch.autograd.profiler.record_function("backbone"):
-            hr_bev, bev_feats, bev_intermediates = self.backbone(camera_feats, batch)
+            feat = camera_feats[cam][start_frame]
+            target_grid = Grid3d.from_volume(
+                data=torch.empty(feat.size(0), self.hr_dim, *self.grid_shape, device=feat.device, dtype=feat.dtype),
+                voxel_size=1.0 / self.scale,
+                time=feat.time,
+            )
+            hr_bev, bev_feats, bev_intermediates = self.backbone(batch, camera_feats, target_grid)
 
         if torch.is_anomaly_check_nan_enabled():
             assert not torch.isnan(hr_bev).any().item()
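Aside, not part of the patch: the reworked forward() above allocates an empty Grid3d for the backbone to fill in. A minimal standalone sketch of that construction, using hypothetical values hr_dim=16, grid_shape=(64, 64, 8), and scale=3.0 (the constant train_config.py passes below); in the real code the dtype, device, and time come from the encoded camera features:

    import torch

    from torchworld.structures.grid import Grid3d

    hr_dim = 16               # hypothetical BEV channel count
    grid_shape = (64, 64, 8)  # hypothetical (X, Y, Z) voxel counts
    scale = 3.0               # voxels per meter, so voxel_size = 1 / scale

    # A batch of two empty (BS, C, X, Y, Z) feature volumes; from_volume
    # infers the grid extents from data.shape[2:5].
    data = torch.empty(2, hr_dim, *grid_shape)
    target_grid = Grid3d.from_volume(
        data=data,
        voxel_size=1.0 / scale,
        time=torch.rand(2),
    )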
diff --git a/torchdrive/train_config.py b/torchdrive/train_config.py
index 96b7dc4..422d33e 100644
--- a/torchdrive/train_config.py
+++ b/torchdrive/train_config.py
@@ -151,12 +151,11 @@ def create_model(
             from torchworld.models.simplebev_3d import SimpleBEV3DBackbone
 
             backbone = SimpleBEV3DBackbone(
-                grid_shape=adjusted_grid_shape,
+                grid_shape=self.grid_shape,
                 dim=self.dim,
                 hr_dim=self.hr_dim,
                 cam_dim=self.cam_dim,
                 num_frames=3,
-                scale=3 / adjust,
                 compile_fn=compile_fn,
             )
         else:
@@ -274,6 +273,8 @@ def cam_encoder() -> RegNetEncoder:
         dim=self.dim,
         hr_dim=self.hr_dim,
         cam_dim=self.cam_dim,
+        grid_shape=self.grid_shape,
+        scale=3.0,
         cam_features_mask_ratio=self.cam_features_mask_ratio,
         compile_fn=compile_fn,
         num_encode_frames=self.num_encode_frames,
diff --git a/torchworld/models/simplebev_3d.py b/torchworld/models/simplebev_3d.py
index aaae0cb..7da745a 100644
--- a/torchworld/models/simplebev_3d.py
+++ b/torchworld/models/simplebev_3d.py
@@ -85,15 +85,16 @@ def forward(
         x1, x2, x3, x4 = self.fpn(x)
         assert x1.shape == x.shape
 
-        print(self.bev_project)
-        print(x1.shape, x2.shape, x3.shape, x4.shape)
-
         x0 = x1
 
         # project to BEV grids
-        x1 = self.bev_project[0](x1.flatten(1, 2))
-        x2 = self.bev_project[1](x2.flatten(1, 2))
-        x3 = self.bev_project[2](x3.flatten(1, 2))
-        x4 = self.bev_project[3](x4.flatten(1, 2))
+        x1 = x1.transpose(2, 4).clone().flatten(1, 2)
+        x1 = self.bev_project[0](x1)
+        x2 = x2.transpose(2, 4).clone().flatten(1, 2)
+        x2 = self.bev_project[1](x2)
+        x3 = x3.transpose(2, 4).clone().flatten(1, 2)
+        x3 = self.bev_project[2](x3)
+        x4 = x4.transpose(2, 4).clone().flatten(1, 2)
+        x4 = self.bev_project[3](x4)
 
         return x0, (x1, x2, x3, x4), {}
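Aside, not part of the patch: a shape walk-through of the reworked BEV projection above, assuming a hypothetical FPN output of shape (BS, C, X, Y, Z) = (2, 8, 16, 16, 4):

    import torch

    x1 = torch.rand(2, 8, 16, 16, 4)  # (BS, C, X, Y, Z)
    x1 = x1.transpose(2, 4).clone()   # (BS, C, Z, Y, X); clone() materializes the view
    x1 = x1.flatten(1, 2)             # (BS, C * Z, Y, X): the Z axis folds into channels
    assert x1.shape == (2, 8 * 4, 16, 16)

Each bev_project[i] module can then map the C * Z input channels of a (Y, X) map down to the BEV feature dimension with ordinary 2D convolutions.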
diff --git a/torchworld/structures/grid.py b/torchworld/structures/grid.py
index d95abdb..b68a93c 100644
--- a/torchworld/structures/grid.py
+++ b/torchworld/structures/grid.py
@@ -67,6 +67,9 @@ def __new__(
 
         return r
 
+    def numpy(self) -> object:
+        return self._data.numpy()
+
     @classmethod
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
@@ -122,13 +125,6 @@ def __post_init__(self) -> None:
                 f"time must be scalar or 1-dimensional, got {self.time.shape}"
             )
 
-        T = self.local_to_world.get_matrix()
-        if (BS := T.size(0)) != 1:
-            if BS != self._data.size(0):
-                raise TypeError(
-                    f"data and local_to_world batch sizes don't match: {T.shape, self._data.shape}"
-                )
-
     @classmethod
     def from_volume(
         cls,
@@ -155,6 +151,7 @@ def from_volume(
         """
         device = data.device
         grid_sizes = tuple(data.shape[2:5])
+        print(grid_sizes)
         locator = VolumeLocator(
             batch_size=len(data),
             grid_sizes=grid_sizes,
@@ -264,7 +261,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 camera=grid.camera,
                 time=time,
                 mask=mask,
-            )
+            ) if isinstance(out, torch.Tensor) else out
             for out in out_flat
         ]
         out = pytree.tree_unflatten(out_flat, spec)
@@ -296,16 +293,11 @@ def __post_init__(self) -> None:
                 f"time must be scalar or 1-dimensional, got {self.time.shape}"
            )
 
-        T = self.camera.get_projection_transform().get_matrix()
-
-        if (BS := T.size(0)) != 1:
-            if BS != self._data.size(0):
-                raise TypeError(
-                    f"data and transform batch sizes don't match: {T.shape, self._data.shape}"
-                )
-
     def grid_shape(self) -> Tuple[int, int]:
         return tuple(self._data.shape[2:4])
 
     def __repr__(self):
         return f"GridImage(data={self._data}, camera={self.camera}, time={self.time}), mask={self.mask}"
+
+    def numpy(self) -> object:
+        return self._data.numpy()
diff --git a/torchworld/structures/test_grid.py b/torchworld/structures/test_grid.py
index 0991698..b754355 100644
--- a/torchworld/structures/test_grid.py
+++ b/torchworld/structures/test_grid.py
@@ -44,16 +44,17 @@ def test_grid3d_repr(self) -> None:
             local_to_world=Transform3d(),
             time=torch.rand(2),
         )
+        str(grid)
+        repr(grid)
 
-    def test_grid3d_conv3d(self) -> None:
+    def test_grid3d_numpy(self) -> None:
         grid = Grid3d(
             data=torch.rand(2, 3, 4, 5, 6),
             local_to_world=Transform3d(),
             time=torch.rand(2),
         )
+        grid.numpy()
 
-        str(grid)
-        repr(grid)
 
     def test_grid_image(self) -> None:
         grid = GridImage(
@@ -77,6 +78,14 @@ def test_grid_image_repr(self) -> None:
         str(grid)
         repr(grid)
 
+    def test_grid_image_numpy(self) -> None:
+        grid = GridImage(
+            data=torch.rand(2, 3, 4, 5),
+            camera=PerspectiveCameras(),
+            time=torch.rand(2),
+        )
+        grid.numpy()
+
     def test_grid_image_permute(self) -> None:
         grid = GridImage(
             data=torch.rand(2, 3, 4, 5),
diff --git a/torchworld/transforms/simplebev.py b/torchworld/transforms/simplebev.py
index d022ece..c40d30f 100644
--- a/torchworld/transforms/simplebev.py
+++ b/torchworld/transforms/simplebev.py
@@ -29,8 +29,8 @@ def lift_image_to_3d(
         features: grid with features
         mask: grid of the mask where the camera could see
     """
-    if dst.numel() != 0:
-        raise TypeError(f"dst should be batch size zero {dst.shape}")
+    # if dst.numel() != 0:
+    #     raise TypeError(f"dst should be batch size zero {dst.shape}")
 
     device = src.device
     BS = len(src)
@@ -95,7 +95,6 @@ def lift_image_to_3d(
         ),
     )
 
-
 def merge_grids(
     grids: Tuple[Grid3d, ...],
     masks: Tuple[Grid3d, ...],
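Aside, not part of the patch: a minimal usage sketch for the new Grid3d.numpy() and GridImage.numpy() helpers exercised by the tests above, assuming Transform3d and PerspectiveCameras come from pytorch3d as elsewhere in torchworld:

    import torch
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.transforms import Transform3d

    from torchworld.structures.grid import Grid3d, GridImage

    grid = Grid3d(
        data=torch.rand(2, 3, 4, 5, 6),  # (BS, C, X, Y, Z)
        local_to_world=Transform3d(),
        time=torch.rand(2),
    )
    assert grid.numpy().shape == (2, 3, 4, 5, 6)  # plain numpy.ndarray

    img = GridImage(
        data=torch.rand(2, 3, 4, 5),  # (BS, C, H, W)
        camera=PerspectiveCameras(),
        time=torch.rand(2),
    )
    assert img.numpy().shape == (2, 3, 4, 5)

As with torch.Tensor.numpy(), the returned array shares storage with the grid's underlying tensor, so this only works for CPU tensors that do not require grad.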