Skip to content

Commit

Permalink
Adding dinov2 features
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 committed Jan 24, 2024
1 parent 08d1c1b commit c7aff74
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 16 deletions.
4 changes: 2 additions & 2 deletions wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class AblationDataModuleParams:

@dataclass
class ModelParams:
name: str = "LinearRnvp" # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
name: str = "SimpleMLP" # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
load_ckpt: Optional[str] = None

@dataclass
class SimpleMlpCfgParams:
input_size: int = 384
input_size: int = 768 # 384
hidden_sizes: List[int] = field(default_factory=lambda: [256, 32, 1])
reconstruction: bool = True

Expand Down
1 change: 1 addition & 0 deletions wild_visual_navigation/feature_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dino_interface import DinoInterface, run_dino_interfacer
from .dino2_interface import Dino2Interface
from .torchvision_interface import TorchVisionInterface

# from .dino_trt_interface import DinoTrtInterface, TrtModel, run_dino_trt_interfacer
Expand Down
83 changes: 83 additions & 0 deletions wild_visual_navigation/feature_extractor/dino2_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from wild_visual_navigation import WVN_ROOT_DIR
import torch
from torchvision import transforms as T
import torch.nn.functional as F
import time
from typing import Tuple
from torch.cuda.amp import autocast


class Dino2Interface:
def __init__(
self,
device: str,
model_type: str = "vit_small",
**kwargs
):

self._model_type = model_type
# Initialize DINOv2
if self._model_type == "vit_small":
self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
self.embed_dim = 384
elif self._model_type == "vit_base":
self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
self.embed_dim = 768
elif self._model_type == "vit_large":
self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
self.embed_dim = 1024
elif self._model_type == "vit_huge":
self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
self.embed_dim = 1536
self.patch_size = kwargs.get("patch_size", 14)
# Send to device
self.model.to(device)
self.device = device

self.transform = self._create_transform()

def _create_transform(self):
# Resize and then center crop to the expected input size
transform = T.Compose([
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return transform

@torch.no_grad()
def inference(self, img: torch.tensor):
# check if it has a batch dim or not
if img.dim() == 3:
img = img.unsqueeze(0)

# Resize and normalize
img = self.transform(img)
# Send to device
img = img.to(self.device)
# print("After transform shape is:",img.shape)
# Inference
with autocast():
feat = self.model.forward_features(img)["x_norm_patchtokens"]
B = feat.shape[0]
C = feat.shape[2]
H = int(img.shape[2] / self.patch_size)
W = int(img.shape[3] / self.patch_size)
feat = feat.permute(0, 2, 1)
feat = feat.reshape(B, C, H, W)

# resize and interpolate features
B, D, H, W = img.shape
new_size = (H, H)
pad = int((W - H) / 2)
feat = F.interpolate(feat, new_size, mode="bilinear", align_corners=True)
feat = F.pad(feat, pad=[pad, pad, 0, 0])

return feat

def change_device(self, device):
"""Changes the device of all the class members
Args:
device (str): new device
"""
self.model.to(device)
self.device = device
13 changes: 12 additions & 1 deletion wild_visual_navigation/feature_extractor/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
DinoInterface,
SegmentExtractor,
TorchVisionInterface,
Dino2Interface,
)
from pytictac import Timer
import skimage
Expand Down Expand Up @@ -38,8 +39,9 @@ def __init__(
self.extractor = StegoInterface(device=device, input_size=input_size)
elif self._feature_type == "dino":
self._feature_dim = 90

self.extractor = DinoInterface(device=device, input_size=input_size, patch_size=kwargs.get("patch_size", 8), dim=kwargs.get("dino_dim", 384))
elif self._feature_type == "dino2":
self.extractor = Dino2Interface(device=device, model_type="vit_base")
elif self._feature_type == "sift":
self._feature_dim = 128
self.extractor = DenseSIFTDescriptor().to(device)
Expand Down Expand Up @@ -224,6 +226,9 @@ def compute_features(self, img: torch.tensor, seg: torch.tensor, center: torch.t
elif self._feature_type == "dino":
feat = self.compute_dino(img, seg, center, **kwargs)

elif self._feature_type == "dino2":
feat = self.compute_dino2(img, seg, center, **kwargs)

elif self._feature_type == "stego":
feat = self.compute_stego(img, seg, center, **kwargs)

Expand Down Expand Up @@ -257,6 +262,12 @@ def compute_dino(self, img: torch.tensor, seg: torch.tensor, center: torch.tenso
features = self.extractor.inference(img_internal)
return features

@torch.no_grad()
def compute_dino2(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
    """Run the configured extractor on a copy of ``img`` and return its features.

    ``seg``, ``center`` and ``**kwargs`` are accepted but not used here.
    """
    # Work on a clone so the extractor never operates on the caller's tensor.
    return self.extractor.inference(img.clone())

@torch.no_grad()
def compute_torchvision(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
img_internal = img.clone()
Expand Down
Loading

0 comments on commit c7aff74

Please sign in to comment.