Skip to content

Commit

Permalink
Add Extractor, cleanup SIFT, bugfixes in LightGlue
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Oct 18, 2023
1 parent 2777e03 commit d348f25
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 233 deletions.
27 changes: 4 additions & 23 deletions lightglue/aliked.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
# Code from https://github.com/Shiaoming/ALIKED

from types import SimpleNamespace
from typing import Callable, Optional

import torch
Expand All @@ -42,7 +41,7 @@
from torch.nn.modules.utils import _pair
from torchvision.models import resnet

from .utils import ImagePreprocessor
from .utils import Extractor


def get_patches(
Expand Down Expand Up @@ -609,12 +608,11 @@ def forward(self, x, keypoints):
return descriptors, offsets


class ALIKED(nn.Module):
class ALIKED(Extractor):
default_conf = {
"model_name": "aliked-n16",
"max_num_keypoints": -1,
"detection_threshold": 0.2,
"force_num_keypoints": False,
"nms_radius": 2,
}

Expand All @@ -630,20 +628,15 @@ class ALIKED(nn.Module):
"aliked-n32": [16, 32, 64, 128, 128, 3, 32],
}
preprocess_conf = {
**ImagePreprocessor.default_conf,
"resize": 1024,
"grayscale": False,
}

required_data_keys = ["image"]

def __init__(self, **conf):
super().__init__()
self.conf = {**self.default_conf, **conf}
conf = self.conf = SimpleNamespace(**self.conf)
if conf.force_num_keypoints:
assert conf.detection_threshold <= 0 and conf.max_num_keypoints > 0
# get configurations
super().__init__(**conf) # Update with default configuration.
conf = self.conf
c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
conv_types = ["conv", "conv", "dcn", "dcn"]
conv2D = False
Expand Down Expand Up @@ -761,15 +754,3 @@ def forward(self, data: dict) -> dict:
"descriptors": torch.stack(descriptors), # B x N x D
"keypoint_scores": torch.stack(kptscores), # B x N
}

def extract(self, img: torch.Tensor, **conf) -> dict:
"""Perform extraction with online resizing"""
if img.dim() == 3:
img = img[None] # add batch dim
assert img.dim() == 4 and img.shape[0] == 1
shape = img.shape[-2:][::-1]
img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
feats = self.forward({"image": img})
feats["image_size"] = torch.tensor(shape)[None].to(img).float()
feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
return feats
24 changes: 3 additions & 21 deletions lightglue/disk.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from types import SimpleNamespace

import kornia
import torch
import torch.nn as nn

from .utils import ImagePreprocessor
from .utils import Extractor


class DISK(nn.Module):
class DISK(Extractor):
default_conf = {
"weights": "depth",
"max_num_keypoints": None,
Expand All @@ -18,17 +15,14 @@ class DISK(nn.Module):
}

preprocess_conf = {
**ImagePreprocessor.default_conf,
"resize": 1024,
"grayscale": False,
}

required_data_keys = ["image"]

def __init__(self, **conf) -> None:
super().__init__()
self.conf = {**self.default_conf, **conf}
self.conf = SimpleNamespace(**self.conf)
super().__init__(**conf) # Update with default configuration.
self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

def forward(self, data: dict) -> dict:
Expand Down Expand Up @@ -57,15 +51,3 @@ def forward(self, data: dict) -> dict:
"keypoint_scores": scores.to(image).contiguous(),
"descriptors": descriptors.to(image).contiguous(),
}

def extract(self, img: torch.Tensor, **conf) -> dict:
"""Perform extraction with online resizing"""
if img.dim() == 3:
img = img[None] # add batch dim
assert img.dim() == 4 and img.shape[0] == 1
shape = img.shape[-2:][::-1]
img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
feats = self.forward({"image": img})
feats["image_size"] = torch.tensor(shape)[None].to(img).float()
feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
return feats
20 changes: 14 additions & 6 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,15 @@ class LightGlue(nn.Module):

def __init__(self, features="superpoint", **conf) -> None:
super().__init__()
self.conf = {**self.default_conf, **conf}
self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
if features is not None:
assert features in list(self.features.keys())
self.conf["weights"], self.conf["input_dim"] = self.features[features]
self.conf = conf = SimpleNamespace(**self.conf)
if features not in self.features:
raise ValueError(
f"Unsupported features: {features} not in "
f"{{{','.join(self.features)}}}"
)
for k, v in self.features[features].items():
setattr(conf, k, v)

if conf.input_dim != conf.descriptor_dim:
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
Expand Down Expand Up @@ -471,8 +475,12 @@ def _forward(self, data: dict) -> dict:
kpts1 = normalize_keypoints(kpts1, size1).clone()

if self.conf.add_scale_ori:
kpts0 = torch.cat([kpts0, data0["scales"], data0["oris"]], -1)
kpts1 = torch.cat([kpts1, data1["scales"], data1["oris"]], -1)
kpts0 = torch.cat(
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
kpts1 = torch.cat(
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
desc0 = data0["descriptors"].detach().contiguous()
desc1 = data1["descriptors"].detach().contiguous()

Expand Down
Loading

0 comments on commit d348f25

Please sign in to comment.