Skip to content

Commit

Permalink
apply code review suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
ducha-aiki committed Jan 23, 2024
1 parent dcc4288 commit fb6f1ac
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 31 deletions.
1 change: 1 addition & 0 deletions lightglue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .lightglue import LightGlue # noqa
from .sift import SIFT # noqa
from .superpoint import SuperPoint # noqa
from .dog_hardnet import DoGHardNet # noqa
from .utils import match_pair # noqa
39 changes: 8 additions & 31 deletions lightglue/dog_hardnet.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,30 @@
import warnings

import cv2
import numpy as np
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
from packaging import version

from .utils import Extractor
from .sift import SIFT

try:
import pycolmap
except ImportError:
pycolmap = None


class DoGHardNet(SIFT):
default_conf = {
"nms_radius": 0, # None to disable filtering entirely.
"max_num_keypoints": 2048,
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
"detection_threshold": -1, # from COLMAP
"edge_threshold": -1,
"first_octave": -1, # only used by pycolmap, the default of COLMAP
"num_octaves": 4,
"force_num_keypoints": True,
}

required_data_keys = ["image"]

def _init(self, conf):
super()._init(conf)
def __init__(self, **conf):
super().__init__(**conf)
self.laf_desc = LAFDescriptor(HardNet(True)).eval()


def _forward(self, data: dict) -> dict:
def forward(self, data: dict) -> dict:
image = data["image"]
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
device = image.device
self.laf_desc = self.laf_desc.to(device)
self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
pred = []
im_size = data.get("image_size").long()
if "image_size" in data.keys():
im_size = data.get("image_size").long()
else:
im_size = None
for k in range(len(image)):
img = image[k]
if im_size is not None:
Expand All @@ -58,8 +39,4 @@ def _forward(self, data: dict) -> dict:
p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
pred.append(p)
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
return pred




return pred

0 comments on commit fb6f1ac

Please sign in to comment.