more intermediate work to switch to torchworld
d4l3k committed Apr 6, 2024
1 parent d9f12d6 commit e2283bc
Showing 7 changed files with 63 additions and 47 deletions.
2 changes: 1 addition & 1 deletion configs/simplebev3d_multi_pose.py
@@ -35,7 +35,7 @@
     autolabel_path=None, # "/mnt/ext3/autolabel2",
     mask_path="n/a", # only used for rice dataset
     num_workers=4,
-    batch_size=6,
+    batch_size=4,
     # tasks
     det=False,
     ae=False,
44 changes: 29 additions & 15 deletions torchdrive/tasks/bev.py
@@ -9,7 +9,9 @@
 from torch import nn
 from torch.cuda import amp
 from torch.utils.tensorboard import SummaryWriter
 
+from torchworld.transforms.img import render_color
+from torchworld.structures.grid import Grid3d
 
 from torchdrive.amp import autocast
 from torchdrive.autograd import (
@@ -57,6 +59,8 @@ def __init__(
         dim: int,
         cam_dim: int,
         hr_dim: int,
+        scale: float,
+        grid_shape: Tuple[int, int, int],
         cam_features_mask_ratio: float = 0.0,
         num_encode_frames: int = 3,
         num_backprop_frames: int = 2,
@@ -85,6 +89,9 @@ def __init__(
         self.camera_encoders = nn.ModuleDict(
             {cam: compile_fn(cam_encoder()) for cam in cameras}
         )
+        self.grid_shape = grid_shape
+        self.hr_dim = hr_dim
+        self.scale = scale
 
         for task in tasks.values():
             task.set_camera_encoders(self.camera_encoders)
@@ -167,35 +174,42 @@ def forward(
             # run frames in parallel
             encoder = self.camera_encoders[cam]
             encoder.train()
-            inp = batch.color[cam][:, : self.num_encode_frames]
+            inp = [batch.grid_image(cam, i) for i in range(self.num_encode_frames)]
 
             # use gradient checkpointing to save memory
-            feats = torch.utils.checkpoint.checkpoint(
-                encoder, inp.flatten(0, 1)
-            ).unflatten(0, inp.shape[0:2])
-            num_frames = feats.size(1)
+            feats = [
+                torch.utils.checkpoint.checkpoint(
+                    encoder, frame_inp
+                ) for frame_inp in inp
+            ]
+            num_frames = self.num_encode_frames
 
             to_mask = (
-                torch.rand(feats.shape[-2:], device=feats.device)
+                torch.rand(feats[0].shape[-2:], device=feats[0].device)
                 < self.cam_features_mask_ratio
             )
-            feats[:, :, :, to_mask] = (
-                self.cam_mask_value.weight.unsqueeze(0)
-                .unsqueeze(-1)
-                .to(feats.dtype)
-            )
+            for feat in feats:
+                feat[:, :, to_mask] = (
+                    self.cam_mask_value.weight.unsqueeze(0)
+                    .unsqueeze(-1)
+                    .to(feat.dtype)
+                )
 
-            for i in range(num_frames):
-                feat = feats[:, i]
-                camera_feats[cam].append(feat)
+            camera_feats[cam] = feats
 
         for cam, cam_feats in camera_feats.items():
             if torch.is_anomaly_check_nan_enabled():
                 for feats in cam_feats:
                     assert not torch.isnan(feats).any().item(), cam
 
         with torch.autograd.profiler.record_function("backbone"):
-            hr_bev, bev_feats, bev_intermediates = self.backbone(camera_feats, batch)
+            feat = camera_feats[cam][start_frame]
+            target_grid = Grid3d.from_volume(
+                data=torch.empty(feat.size(0), self.hr_dim, *self.grid_shape, device=feat.device, dtype=feat.dtype),
+                voxel_size=1.0/self.scale,
+                time=feat.time,
+            )
+            hr_bev, bev_feats, bev_intermediates = self.backbone(batch, camera_feats, target_grid)
 
         if torch.is_anomaly_check_nan_enabled():
             assert not torch.isnan(hr_bev).any().item()
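The forward() change above replaces one flattened (batch × time) encoder call with one gradient-checkpointed call per frame, then applies a single shared spatial mask to every frame's features. A minimal, self-contained sketch of that pattern; TinyEncoder, the shapes, and the zero mask value are hypothetical stand-ins (the real code encodes batch.grid_image(cam, i) with per-camera encoders and writes a learned cam_mask_value embedding):

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x).relu()


encoder = TinyEncoder()
num_encode_frames = 3
# one (BS, 3, H, W) tensor per frame, standing in for batch.grid_image(cam, i)
inp = [torch.rand(2, 3, 48, 64) for _ in range(num_encode_frames)]

# one checkpointed call per frame: activations inside the encoder are
# recomputed during backward instead of being kept alive for all frames
feats = [
    torch.utils.checkpoint.checkpoint(encoder, frame_inp, use_reentrant=False)
    for frame_inp in inp
]

# one shared random spatial mask, applied in-place to every frame's features
cam_features_mask_ratio = 0.25
to_mask = torch.rand(feats[0].shape[-2:]) < cam_features_mask_ratio
for feat in feats:
    feat[:, :, to_mask] = 0.0  # the real code writes a learned embedding here

print([tuple(f.shape) for f in feats])  # [(2, 8, 48, 64)] x 3
```

Keeping the frames as a list also matches the new camera_feats[cam] = feats layout, where each element is a separate per-frame feature map rather than a slice of one stacked tensor.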
5 changes: 3 additions & 2 deletions torchdrive/train_config.py
@@ -151,12 +151,11 @@ def create_model(
             from torchworld.models.simplebev_3d import SimpleBEV3DBackbone
 
             backbone = SimpleBEV3DBackbone(
-                grid_shape=adjusted_grid_shape,
+                grid_shape=self.grid_shape,
                 dim=self.dim,
                 hr_dim=self.hr_dim,
                 cam_dim=self.cam_dim,
                 num_frames=3,
-                scale=3 / adjust,
                 compile_fn=compile_fn,
             )
         else:
@@ -274,6 +273,8 @@ def cam_encoder() -> RegNetEncoder:
             dim=self.dim,
             hr_dim=self.hr_dim,
             cam_dim=self.cam_dim,
+            grid_shape=self.grid_shape,
+            scale=3.0,
             cam_features_mask_ratio=self.cam_features_mask_ratio,
             compile_fn=compile_fn,
             num_encode_frames=self.num_encode_frames,
15 changes: 8 additions & 7 deletions torchworld/models/simplebev_3d.py
@@ -85,15 +85,16 @@ def forward(
         x1, x2, x3, x4 = self.fpn(x)
         assert x1.shape == x.shape
 
-        print(self.bev_project)
-        print(x1.shape, x2.shape, x3.shape, x4.shape)
-
         x0 = x1
 
         # project to BEV grids
-        x1 = self.bev_project[0](x1.flatten(1, 2))
-        x2 = self.bev_project[1](x2.flatten(1, 2))
-        x3 = self.bev_project[2](x3.flatten(1, 2))
-        x4 = self.bev_project[3](x4.flatten(1, 2))
+        x1 = x1.transpose(2, 4).clone().flatten(1, 2)
+        x1 = self.bev_project[0](x1)
+        x2 = x2.transpose(2, 4).clone().flatten(1, 2)
+        x2 = self.bev_project[1](x2)
+        x3 = x3.transpose(2, 4).clone().flatten(1, 2)
+        x3 = self.bev_project[2](x3)
+        x4 = x4.transpose(2, 4).clone().flatten(1, 2)
+        x4 = self.bev_project[3](x4)
 
         return x0, (x1, x2, x3, x4), {}
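The projection change above no longer folds the voxel grid's first spatial axis into channels; it transposes dims 2 and 4 first, so the opposite axis is folded in and the remaining two form the BEV plane. A small sketch of the shape arithmetic, assuming a (BS, C, X, Y, Z) layout (the axis naming is an assumption, not taken from torchworld):

```python
import torch
from torch import nn

BS, C, X, Y, Z = 2, 16, 8, 8, 4
x1 = torch.rand(BS, C, X, Y, Z)

# old: fold dim 2 (X) into channels -> (BS, C*X, Y, Z)
old_bev = x1.flatten(1, 2)

# new: swap dims 2 and 4 first, so Z is folded into channels instead,
# leaving a (Y, X) bird's-eye plane -> (BS, C*Z, Y, X);
# .clone() copies out of the shared storage before flattening, as in the diff
new_bev = x1.transpose(2, 4).clone().flatten(1, 2)

bev_project = nn.Conv2d(C * Z, 32, kernel_size=1)  # stand-in projection conv
out = bev_project(new_bev)
print(old_bev.shape, new_bev.shape, out.shape)
# torch.Size([2, 128, 8, 4]) torch.Size([2, 64, 8, 8]) torch.Size([2, 32, 8, 8])
```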
24 changes: 8 additions & 16 deletions torchworld/structures/grid.py
@@ -67,6 +67,9 @@ def __new__(
 
         return r
 
+    def numpy(self) -> object:
+        return self._data.numpy()
+
     @classmethod
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
@@ -122,13 +125,6 @@ def __post_init__(self) -> None:
                 f"time must be scalar or 1-dimensional, got {self.time.shape}"
             )
 
-        T = self.local_to_world.get_matrix()
-        if (BS := T.size(0)) != 1:
-            if BS != self._data.size(0):
-                raise TypeError(
-                    f"data and local_to_world batch sizes don't match: {T.shape, self._data.shape}"
-                )
-
     @classmethod
     def from_volume(
         cls,
@@ -155,6 +151,7 @@ def from_volume(
         """
         device = data.device
         grid_sizes = tuple(data.shape[2:5])
+        print(grid_sizes)
         locator = VolumeLocator(
             batch_size=len(data),
             grid_sizes=grid_sizes,
@@ -264,7 +261,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 camera=grid.camera,
                 time=time,
                 mask=mask,
-            )
+            ) if isinstance(out, torch.Tensor) else out
             for out in out_flat
         ]
         out = pytree.tree_unflatten(out_flat, spec)
@@ -296,16 +293,11 @@ def __post_init__(self) -> None:
                 f"time must be scalar or 1-dimensional, got {self.time.shape}"
             )
 
-        T = self.camera.get_projection_transform().get_matrix()
-
-        if (BS := T.size(0)) != 1:
-            if BS != self._data.size(0):
-                raise TypeError(
-                    f"data and transform batch sizes don't match: {T.shape, self._data.shape}"
-                )
-
     def grid_shape(self) -> Tuple[int, int]:
         return tuple(self._data.shape[2:4])
 
     def __repr__(self):
         return f"GridImage(data={self._data}, camera={self.camera}, time={self.time}), mask={self.mask}"
+
+    def numpy(self) -> object:
+        return self._data.numpy()
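The new `) if isinstance(out, torch.Tensor) else out` guard in __torch_dispatch__ exists because the pytree-flattened outputs of an op can contain non-tensor leaves (ints, bools, None) that must be passed through unwrapped. A minimal sketch of the same pattern on a bare tensor subclass; Wrapper is hypothetical, not torchworld's Grid3d:

```python
import torch
from torch.utils import _pytree as pytree


class Wrapper(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # unwrap inputs to plain tensors so func hits the normal kernels
        def unwrap(t: object) -> object:
            return t.as_subclass(torch.Tensor) if isinstance(t, Wrapper) else t

        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        out_flat, spec = pytree.tree_flatten(out)
        # rewrap only the tensor leaves; non-tensor leaves pass through untouched
        out_flat = [
            o.as_subclass(Wrapper) if isinstance(o, torch.Tensor) else o
            for o in out_flat
        ]
        return pytree.tree_unflatten(out_flat, spec)


x = torch.rand(2, 3).as_subclass(Wrapper)
y = x + 1  # dispatch rewraps the tensor result as a Wrapper
print(type(y), y.shape)
```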
15 changes: 12 additions & 3 deletions torchworld/structures/test_grid.py
@@ -44,16 +44,17 @@ def test_grid3d_repr(self) -> None:
             local_to_world=Transform3d(),
             time=torch.rand(2),
         )
         str(grid)
         repr(grid)
 
-    def test_grid3d_conv3d(self) -> None:
+    def test_grid3d_numpy(self) -> None:
         grid = Grid3d(
             data=torch.rand(2, 3, 4, 5, 6),
             local_to_world=Transform3d(),
             time=torch.rand(2),
         )
+        grid.numpy()
 
-        str(grid)
-        repr(grid)
 
     def test_grid_image(self) -> None:
         grid = GridImage(
@@ -77,6 +78,14 @@ def test_grid_image_repr(self) -> None:
         str(grid)
         repr(grid)
 
+    def test_grid_image_numpy(self) -> None:
+        grid = GridImage(
+            data=torch.rand(2, 3, 4, 5),
+            camera=PerspectiveCameras(),
+            time=torch.rand(2),
+        )
+        grid.numpy()
+
     def test_grid_image_permute(self) -> None:
         grid = GridImage(
             data=torch.rand(2, 3, 4, 5),
5 changes: 2 additions & 3 deletions torchworld/transforms/simplebev.py
@@ -29,8 +29,8 @@ def lift_image_to_3d(
         features: grid with features
         mask: grid of the mask where the camera could see
     """
-    if dst.numel() != 0:
-        raise TypeError(f"dst should be batch size zero {dst.shape}")
+    #if dst.numel() != 0:
+    #    raise TypeError(f"dst should be batch size zero {dst.shape}")
 
     device = src.device
     BS = len(src)
@@ -95,7 +95,6 @@ def lift_image_to_3d(
         ),
     )
 
-
 def merge_grids(
     grids: Tuple[Grid3d, ...],
     masks: Tuple[Grid3d, ...],
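With the dst check commented out, lift_image_to_3d no longer insists the target grid be empty (numel() == 0), so a grid that carries a real batch of data — like the target_grid now built in torchdrive/tasks/bev.py — can be passed directly. A hedged usage sketch; the shapes, the (BS, C, Z, Y, X) axis order, the voxel size, and the default PerspectiveCameras are illustrative assumptions, not values taken from this diff:

```python
import torch
from pytorch3d.renderer import PerspectiveCameras

from torchworld.structures.grid import Grid3d, GridImage
from torchworld.transforms.simplebev import lift_image_to_3d

# per-camera image features; constructor kwargs match the tests in this commit
src = GridImage(
    data=torch.rand(2, 16, 48, 64),  # assumed (BS, C, H, W)
    camera=PerspectiveCameras(),
    time=torch.zeros(2),
)

# a batched target volume describing the output geometry, mirroring the
# Grid3d.from_volume call added to BEVTaskVan.forward above
dst = Grid3d.from_volume(
    data=torch.empty(2, 16, 8, 64, 64),  # assumed (BS, C, Z, Y, X)
    voxel_size=1.0 / 3.0,
    time=src.time,
)

# per the docstring: a feature grid plus a visibility mask
features, mask = lift_image_to_3d(src, dst)
```

Note that dropping the GridImage batch-size check in grid.py (the deleted T/BS comparison) is what lets a batch-1 default camera pair with batched data here.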
