Skip to content

Commit

Permalink
vista: trajectory conditioning is working
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 22, 2024
1 parent fd4b555 commit e798e57
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@ export/
*.trt
*.npy
*.onnx
*.png
16 changes: 10 additions & 6 deletions torchdrive/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,16 @@ class SampleData(TypedDict):


def normalize01(tensor: Tensor) -> Tensor:
return transforms.functional.normalize(
tensor,
[0.3504, 0.4324, 0.2892],
[0.0863, 0.1097, 0.0764],
inplace=True,
)
"""
rescale to the range [-1, 1] (assumes the input is in [0, 1])
"""
return tensor * 2.0 - 1.0
# return transforms.functional.normalize(
# tensor,
# [0.3504, 0.4324, 0.2892],
# [0.0863, 0.1097, 0.0764],
# inplace=True,
# )


class CamTypes(StrEnum):
Expand Down
42 changes: 39 additions & 3 deletions torchdrive/models/vista.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch.nn.functional as F

from omegaconf import ListConfig, OmegaConf
from torchvision.transforms.functional import to_pil_image
from torchworld.transforms.img import normalize_img

from vwm.sample_utils import (
do_sample,
Expand All @@ -30,6 +32,18 @@ def __init__(
cond_aug: float = 0.0,
render_size: Tuple[int, int] = (320, 576),
) -> None:
"""
Args:
config_path: path to the config file
ckpt_path: path to the checkpoint file
device: device to run inference on
steps: number of diffusion steps
cfg_scale: classifier-free guidance scale
num_frames: number of frames to generate
NOTE Vista is trained at 10hz and Nuscenes is 12hz
cond_aug: augmentation strength (extra noise)
render_size: size of the image passed into Vista
"""
config_path = os.path.expanduser(config_path)
ckpt_path = os.path.expanduser(ckpt_path)

Expand Down Expand Up @@ -73,11 +87,22 @@ def generate(
h, w = cond_img.shape[2:]

assert trajectory.size(-1) == 2
trajectory = trajectory.squeeze(0)[1:5].flatten()
# keep 4 waypoints (indices 1-4, skipping index 0), i.e. 2 seconds
trajectory = trajectory.squeeze(0)[1:5]
# switch axis so y+ is forward, and x+ is right
trajectory = torch.stack([-trajectory[:, 1], trajectory[:, 0]], dim=-1)
trajectory = trajectory.flatten()
assert trajectory.shape == (8,)

cond_img = F.interpolate(cond_img, size=self.render_size, mode="bilinear")

amin, amax = cond_img.aminmax()
center = (amax + amin) / 2
dist = (amax - amin) / 2

# recenter to (-1, 1)
cond_img = (cond_img - center) / dist

unique_keys = set([x.input_key for x in self.model.conditioner.embedders])

value_dict = init_embedder_options(unique_keys)
Expand All @@ -99,7 +124,6 @@ def generate(
]

images = cond_img.expand(self.num_frames, -1, -1, -1)
print(images.shape, cond_img.shape, value_dict["trajectory"].shape)

out = do_sample(
images,
Expand All @@ -113,6 +137,13 @@ def generate(
)
samples, samples_z, inputs = out

out_min, out_max = samples.aminmax()
out_center = (out_max + out_min) / 2
out_dist = (out_max - out_min) / 2

# restore original range
samples = (samples - out_center) / out_dist * dist + center

return F.interpolate(samples, (h, w), mode="bilinear")


Expand All @@ -136,8 +167,13 @@ def generate(
trajectory = trajectory[:, ::6, ::2]
cond_img = batch.color["CAM_FRONT"][:1, 0]

print(trajectory)

sampler = VistaSampler(device=device)
out = sampler.generate(cond_img, trajectory)
print(out.shape)
assert out.shape == (10, 3, 480, 640)
print(out)

for i, img in enumerate(out):
img = to_pil_image(normalize_img(img))
img.save(f"vista_{i}.png")

0 comments on commit e798e57

Please sign in to comment.