Skip to content

Commit

Permalink
- support sam box prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
yatengLG committed Dec 11, 2024
1 parent 2e76228 commit c3a2911
Show file tree
Hide file tree
Showing 10 changed files with 350 additions and 145 deletions.
1 change: 1 addition & 0 deletions ISAT/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class STATUSMode(Enum):
class DRAWMode(Enum):
    """Drawing modes the canvas can be in while an annotation is being created."""
    # Manual polygon drawing.
    POLYGON = 0
    # Segment Anything mode (the point-prompt variant, per the UI label "Segment anything point").
    SEGMENTANYTHING = 1
    # Segment Anything mode driven by a rectangular box prompt.
    SEGMENTANYTHING_BOX = 2

class CLICKMode(Enum):
POSITIVE = 0
Expand Down
1 change: 1 addition & 0 deletions ISAT/icons.qrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
<RCC>
<qresource prefix="icon">
<file>../icons/小矩形_rectangle-small.svg</file>
<file>../icons/关闭-小_close-small.svg</file>
<file>../icons/校验-小_check-small.svg</file>
<file>../icons/play-1.svg</file>
Expand Down
283 changes: 157 additions & 126 deletions ISAT/icons_rc.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ISAT/software.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ label:
software:
auto_save: false
contour_mode: external
language: zh
language: en
mask_alpha: 0.6
show_edge: true
show_prompt: true
Expand Down
15 changes: 12 additions & 3 deletions ISAT/ui/MainWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ def setupUi(self, MainWindow):
icon35.addPixmap(QtGui.QPixmap(":/icon/icons/视频_video-two.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
self.actionVideo_to_frames.setIcon(icon35)
self.actionVideo_to_frames.setObjectName("actionVideo_to_frames")
self.actionSegment_anything_box = QtWidgets.QAction(MainWindow)
icon36 = QtGui.QIcon()
icon36.addPixmap(QtGui.QPixmap(":/icon/icons/小矩形_rectangle-small.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
self.actionSegment_anything_box.setIcon(icon36)
self.actionSegment_anything_box.setObjectName("actionSegment_anything_box")
self.menuFile.addAction(self.actionOpen_dir)
self.menuFile.addAction(self.actionSave_dir)
self.menuFile.addSeparator()
Expand All @@ -388,6 +393,7 @@ def setupUi(self, MainWindow):
self.menuTools.addAction(self.actionAuto_segment)
self.menuTools.addAction(self.actionAnno_validator)
self.menuEdit.addAction(self.actionSegment_anything)
self.menuEdit.addAction(self.actionSegment_anything_box)
self.menuEdit.addAction(self.actionPolygon)
self.menuEdit.addSeparator()
self.menuEdit.addAction(self.actionVideo_segment)
Expand Down Expand Up @@ -428,6 +434,7 @@ def setupUi(self, MainWindow):
self.toolBar.addAction(self.actionNext)
self.toolBar.addSeparator()
self.toolBar.addAction(self.actionSegment_anything)
self.toolBar.addAction(self.actionSegment_anything_box)
self.toolBar.addAction(self.actionPolygon)
self.toolBar.addSeparator()
self.toolBar.addAction(self.actionVideo_segment_once)
Expand Down Expand Up @@ -507,8 +514,8 @@ def retranslateUi(self, MainWindow):
self.actionNext.setShortcut(_translate("MainWindow", "D"))
self.actionShortcut.setText(_translate("MainWindow", "Shortcut"))
self.actionAbout.setText(_translate("MainWindow", "About"))
self.actionSegment_anything.setText(_translate("MainWindow", "Segment anything"))
self.actionSegment_anything.setToolTip(_translate("MainWindow", "Segment anything"))
self.actionSegment_anything.setText(_translate("MainWindow", "Segment anything point"))
self.actionSegment_anything.setToolTip(_translate("MainWindow", "Segment anything point"))
self.actionSegment_anything.setStatusTip(_translate("MainWindow", "Quick annotate using Segment anything."))
self.actionSegment_anything.setShortcut(_translate("MainWindow", "Q"))
self.actionDelete.setText(_translate("MainWindow", "Delete"))
Expand Down Expand Up @@ -588,4 +595,6 @@ def retranslateUi(self, MainWindow):
self.actionVideo_segment_five_times.setText(_translate("MainWindow", "Video segment five times"))
self.actionVideo_segment_five_times.setStatusTip(_translate("MainWindow", "Video segment next five frames.(only support sam2 model)"))
self.actionVideo_to_frames.setText(_translate("MainWindow", "Video to frames"))

self.actionSegment_anything_box.setText(_translate("MainWindow", "Segment anything box"))
self.actionSegment_anything_box.setStatusTip(_translate("MainWindow", "Quick annotate using Segment anything."))
self.actionSegment_anything_box.setShortcut(_translate("MainWindow", "W"))
21 changes: 19 additions & 2 deletions ISAT/ui/MainWindow.ui
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
<string>Edit</string>
</property>
<addaction name="actionSegment_anything"/>
<addaction name="actionSegment_anything_box"/>
<addaction name="actionPolygon"/>
<addaction name="separator"/>
<addaction name="actionVideo_segment"/>
Expand Down Expand Up @@ -266,6 +267,7 @@
<addaction name="actionNext"/>
<addaction name="separator"/>
<addaction name="actionSegment_anything"/>
<addaction name="actionSegment_anything_box"/>
<addaction name="actionPolygon"/>
<addaction name="separator"/>
<addaction name="actionVideo_segment_once"/>
Expand Down Expand Up @@ -528,10 +530,10 @@
<normaloff>:/icon/icons/M_Favicon.ico</normaloff>:/icon/icons/M_Favicon.ico</iconset>
</property>
<property name="text">
<string>Segment anything</string>
<string>Segment anything point</string>
</property>
<property name="toolTip">
<string>Segment anything</string>
<string>Segment anything point</string>
</property>
<property name="statusTip">
<string>Quick annotate using Segment anything.</string>
Expand Down Expand Up @@ -993,6 +995,21 @@
<string>Video to frames</string>
</property>
</action>
<!-- Action for the "Segment anything box" entry: rectangle icon, status tip, and the W keyboard shortcut. -->
<action name="actionSegment_anything_box">
<property name="icon">
<iconset resource="../icons.qrc">
<normaloff>:/icon/icons/小矩形_rectangle-small.svg</normaloff>:/icon/icons/小矩形_rectangle-small.svg</iconset>
</property>
<property name="text">
<string>Segment anything box</string>
</property>
<property name="statusTip">
<string>Quick annotate using Segment anything.</string>
</property>
<property name="shortcut">
<string>W</string>
</property>
</action>
</widget>
<resources>
<include location="../icons.qrc"/>
Expand Down
67 changes: 55 additions & 12 deletions ISAT/widgets/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author : LG

from PyQt5 import QtWidgets, QtGui, QtCore
from ISAT.widgets.polygon import Polygon, Vertex, PromptPoint, Line
from ISAT.widgets.polygon import Polygon, Vertex, PromptPoint, Line, Rect
from ISAT.configs import STATUSMode, CLICKMode, DRAWMode, CONTOURMode
import numpy as np
import cv2
Expand All @@ -18,6 +18,7 @@ def __init__(self, mainwindow):
self.mask_item: QtWidgets.QGraphicsPixmapItem = None
self.image_data = None
self.current_graph: Polygon = None
self.current_sam_rect: Rect = None
self.current_line: Line = None
self.mode = STATUSMode.VIEW
self.click = CLICKMode.POSITIVE
Expand Down Expand Up @@ -82,6 +83,7 @@ def change_mode_to_create(self):
self.mainwindow.actionNext.setEnabled(False)

self.mainwindow.actionSegment_anything.setEnabled(False)
self.mainwindow.actionSegment_anything_box.setEnabled(False)
self.mainwindow.actionPolygon.setEnabled(False)
self.mainwindow.actionBackspace.setEnabled(True)
self.mainwindow.actionFinish.setEnabled(True)
Expand Down Expand Up @@ -157,6 +159,7 @@ def change_mode_to_edit(self):
self.mainwindow.actionNext.setEnabled(False)

self.mainwindow.actionSegment_anything.setEnabled(False)
self.mainwindow.actionSegment_anything_box.setEnabled(False)
self.mainwindow.actionPolygon.setEnabled(False)
self.mainwindow.actionBackspace.setEnabled(False)
self.mainwindow.actionFinish.setEnabled(False)
Expand Down Expand Up @@ -198,6 +201,7 @@ def change_mode_to_repaint(self):
self.mainwindow.actionNext.setEnabled(False)

self.mainwindow.actionSegment_anything.setEnabled(False)
self.mainwindow.actionSegment_anything_box.setEnabled(False)
self.mainwindow.actionPolygon.setEnabled(False)
self.mainwindow.actionBackspace.setEnabled(True)
self.mainwindow.actionFinish.setEnabled(False)
Expand Down Expand Up @@ -242,6 +246,10 @@ def start_segment_anything(self):
self.draw_mode = DRAWMode.SEGMENTANYTHING
self.start_draw()

def start_segment_anything_box(self):
    """Select the Segment Anything box draw mode, then begin drawing."""
    # The draw mode is assigned before start_draw() is called.
    self.draw_mode = DRAWMode.SEGMENTANYTHING_BOX
    self.start_draw()

def start_draw_polygon(self):
self.draw_mode = DRAWMode.POLYGON
self.start_draw()
Expand Down Expand Up @@ -269,7 +277,7 @@ def finish_draw(self):
is_crowd = False
note = ''

if self.draw_mode == DRAWMode.SEGMENTANYTHING:
if self.draw_mode == DRAWMode.SEGMENTANYTHING or self.draw_mode == DRAWMode.SEGMENTANYTHING_BOX:
# mask to polygon
# --------------
if self.masks is not None:
Expand Down Expand Up @@ -378,6 +386,12 @@ def finish_draw(self):
self.mainwindow.annos_dock_widget.update_listwidget()

self.current_graph = None

if self.current_sam_rect is not None:
self.current_sam_rect.delete()
self.removeItem(self.current_sam_rect)
self.current_sam_rect = None

self.change_mode_to_view()

# Clear the mask
Expand Down Expand Up @@ -406,6 +420,11 @@ def cancel_draw(self):
for item in self.selectedItems():
item.setSelected(False)

if self.current_sam_rect is not None:
self.current_sam_rect.delete()
self.removeItem(self.current_sam_rect)
self.current_sam_rect = None

self.change_mode_to_view()

self.click_points.clear()
Expand Down Expand Up @@ -763,6 +782,14 @@ def mousePressEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
self.prompt_points.append(prompt_point)
self.addItem(prompt_point)

elif self.draw_mode == DRAWMode.SEGMENTANYTHING_BOX: # sam 矩形框提示
if self.current_sam_rect is None:
self.current_sam_rect = Rect()
self.current_sam_rect.setZValue(2)
self.addItem(self.current_sam_rect)
self.current_sam_rect.addPoint(QtCore.QPointF(sceneX, sceneY))
self.current_sam_rect.addPoint(QtCore.QPointF(sceneX, sceneY))

elif self.draw_mode == DRAWMode.POLYGON:
# Remove the point that follows the mouse cursor
self.current_graph.removePoint(len(self.current_graph.points) - 1)
Expand Down Expand Up @@ -877,6 +904,10 @@ def mouseMoveEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
if self.draw_mode == DRAWMode.POLYGON:
# Update the polygon in real time as the mouse moves
self.current_graph.movePoint(len(self.current_graph.points) - 1, pos)
if self.draw_mode == DRAWMode.SEGMENTANYTHING_BOX:
if self.current_sam_rect is not None:
self.current_sam_rect.movePoint(len(self.current_sam_rect.points) - 1, pos)
self.update_mask()

if self.mode == STATUSMode.REPAINT:
self.current_line.movePoint(len(self.current_line.points) - 1, pos)
Expand Down Expand Up @@ -946,6 +977,23 @@ def update_mask(self):

if len(self.click_points) > 0 and len(self.click_points_mode) > 0:
masks = self.mainwindow.segany.predict_with_point_prompt(self.click_points, self.click_points_mode)
self.masks = masks
color = np.array([0, 0, 255])
h, w = masks.shape[-2:]
mask_image = masks.reshape(h, w, 1) * color.reshape(1, 1, -1)
mask_image = mask_image.astype("uint8")
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
mask_image = cv2.addWeighted(self.image_data, self.mask_alpha, mask_image, 1, 0)
elif self.current_sam_rect is not None:
point1 = self.current_sam_rect.points[0]
point2 = self.current_sam_rect.points[1]
box = np.array([min(point1.x(), point2.x()),
min(point1.y(), point2.y()),
max(point1.x(), point2.x()),
max(point1.y(), point2.y()),
])
masks = self.mainwindow.segany.predict_with_box_prompt(box)

self.masks = masks
color = np.array([0, 0, 255])
h, w = masks.shape[-2:]
Expand All @@ -954,19 +1002,14 @@ def update_mask(self):
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
# Adjust how prominent the mask appears by changing the weight of the original image (self.mask_alpha).
mask_image = cv2.addWeighted(self.image_data, self.mask_alpha, mask_image, 1, 0)
mask_image = QtGui.QImage(mask_image[:], mask_image.shape[1], mask_image.shape[0], mask_image.shape[1] * 3,
QtGui.QImage.Format_RGB888)
mask_pixmap = QtGui.QPixmap(mask_image)
if self.mask_item is not None:
self.mask_item.setPixmap(mask_pixmap)
else:
mask_image = np.zeros(self.image_data.shape, dtype=np.uint8)
mask_image = cv2.addWeighted(self.image_data, 1, mask_image, 0, 0)
mask_image = QtGui.QImage(mask_image[:], mask_image.shape[1], mask_image.shape[0], mask_image.shape[1] * 3,
QtGui.QImage.Format_RGB888)
mask_pixmap = QtGui.QPixmap(mask_image)
if self.mask_item is not None:
self.mask_item.setPixmap(mask_pixmap)
mask_image = QtGui.QImage(mask_image[:], mask_image.shape[1], mask_image.shape[0], mask_image.shape[1] * 3,
QtGui.QImage.Format_RGB888)
mask_pixmap = QtGui.QPixmap(mask_image)
if self.mask_item is not None:
self.mask_item.setPixmap(mask_pixmap)

def backspace(self):
if self.mode == STATUSMode.CREATE:
Expand Down
8 changes: 8 additions & 0 deletions ISAT/widgets/mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def init_segment_anything(self, model_name=None):
return
# Wait for the SAM thread to finish
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)
try:
self.seganythread.wait()
self.seganythread.results_dict.clear()
Expand Down Expand Up @@ -539,6 +540,7 @@ def SeganyEnabled(self):
"""
if not self.use_segment_anything:
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)
return

results = self.seganythread.results_dict.get(self.current_index, {})
Expand All @@ -560,9 +562,11 @@ def SeganyEnabled(self):
self.segany.predictor_with_point_prompt._is_image_set = True

self.actionSegment_anything.setEnabled(True)
self.actionSegment_anything_box.setEnabled(True)
else:
self.segany.predictor_with_point_prompt.reset_image()
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)

def seg_video_start(self, max_frame_num_to_track=None):
if self.current_index == None:
Expand Down Expand Up @@ -1161,6 +1165,7 @@ def change_bit_map_to_semantic(self):
self.annos_dock_widget.listWidget.setEnabled(False)
self.annos_dock_widget.checkBox_visible.setEnabled(False)
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)
self.actionVideo_segment.setEnabled(False)
self.actionVideo_segment_once.setEnabled(False)
self.actionVideo_segment_five_times.setEnabled(False)
Expand Down Expand Up @@ -1189,6 +1194,7 @@ def change_bit_map_to_instance(self):
self.annos_dock_widget.listWidget.setEnabled(False)
self.annos_dock_widget.checkBox_visible.setEnabled(False)
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)
self.actionVideo_segment.setEnabled(False)
self.actionVideo_segment_once.setEnabled(False)
self.actionVideo_segment_five_times.setEnabled(False)
Expand Down Expand Up @@ -1425,6 +1431,7 @@ def init_connect(self):
self.actionExit.triggered.connect(self.exit)

self.actionSegment_anything.triggered.connect(self.scene.start_segment_anything)
self.actionSegment_anything_box.triggered.connect(self.scene.start_segment_anything_box)
self.actionPolygon.triggered.connect(self.scene.start_draw_polygon)
self.actionCancel.triggered.connect(self.scene.cancel_draw)
self.actionBackspace.triggered.connect(self.scene.backspace)
Expand Down Expand Up @@ -1471,6 +1478,7 @@ def reset_action(self):
self.actionPrev.setEnabled(False)
self.actionNext.setEnabled(False)
self.actionSegment_anything.setEnabled(False)
self.actionSegment_anything_box.setEnabled(False)
self.actionPolygon.setEnabled(False)
self.actionVideo_segment.setEnabled(False)
self.actionVideo_segment_once.setEnabled(False)
Expand Down
Loading

0 comments on commit c3a2911

Please sign in to comment.