-
Notifications
You must be signed in to change notification settings - Fork 7
/
run_image.py
474 lines (403 loc) · 17.2 KB
/
run_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------------------------------------------------------
# %% Imports
import argparse
import os.path as osp
from time import perf_counter
import torch
import cv2
import numpy as np
from lib.make_sam import make_sam_from_state_dict
from lib.demo_helpers.ui.window import DisplayWindow, KEY
from lib.demo_helpers.ui.layout import HStack, VStack
from lib.demo_helpers.ui.buttons import ToggleButton, ImmediateButton
from lib.demo_helpers.ui.sliders import HSlider
from lib.demo_helpers.ui.static import StaticMessageBar
from lib.demo_helpers.shared_ui_layout import PromptUIControl, PromptUI, ReusableBaseImage
from lib.demo_helpers.crop_ui import run_crop_ui
from lib.demo_helpers.video_frame_select_ui import run_video_frame_select_ui
from lib.demo_helpers.contours import get_contours_from_mask
from lib.demo_helpers.mask_postprocessing import MaskPostProcessor
from lib.demo_helpers.history_keeper import HistoryKeeper
from lib.demo_helpers.loading import ask_for_path_if_missing, ask_for_model_path_if_missing, load_init_prompts
from lib.demo_helpers.saving import save_image_segmentation, get_save_name, make_prompt_save_data
from lib.demo_helpers.misc import (
get_default_device_string,
make_device_config,
get_total_cuda_vram_usage_mb,
)
# ---------------------------------------------------------------------------------------------------------------------
# %% Set up script args

# Set argparse defaults
# -> These are pulled out so that help strings (below) can reference them via f-strings
default_device = get_default_device_string()
default_image_path = None
default_model_path = None
default_prompts_path = None
default_mask_hint_path = None
default_display_size = 900
default_base_size = 1024
default_window_size = 16  # NOTE(review): appears unused in this script -- confirm before removing
default_show_iou_preds = False

# Define script arguments
parser = argparse.ArgumentParser(description="Script used to run Segment-Anything (SAM) on a single image")
parser.add_argument("-i", "--image_path", default=default_image_path, help="Path to input image")
parser.add_argument("-m", "--model_path", default=default_model_path, type=str, help="Path to SAM model weights")
parser.add_argument(
    "-p",
    "--prompts_path",
    default=default_prompts_path,
    type=str,
    help="Path to a json file containing initial prompts to use on start-up (see saved json results for formatting)",
)
parser.add_argument(
    "--mask_path",
    default=default_mask_hint_path,
    type=str,
    help="Path to a mask image, which will be used as a prompt for segmentation",
)
parser.add_argument(
    "-s",
    "--display_size",
    default=default_display_size,
    type=int,
    help=f"Controls size of displayed results (default: {default_display_size})",
)
parser.add_argument(
    "-d",
    "--device",
    default=default_device,
    type=str,
    help=f"Device to use when running model, such as 'cpu' (default: {default_device})",
)
parser.add_argument(
    "-f32",
    "--use_float32",
    default=False,
    action="store_true",
    help="Use 32-bit floating point model weights. Note: this doubles VRAM usage",
)
parser.add_argument(
    "-ar",
    "--use_aspect_ratio",
    default=False,
    action="store_true",
    help="Process the image at it's original aspect ratio",
)
parser.add_argument(
    "-b",
    "--base_size_px",
    default=default_base_size,
    type=int,
    help=f"Override base model size (default {default_base_size})",
)
# The -q flag toggles away from the default, so its action & help text flip with the default value
parser.add_argument(
    "-q",
    "--quality_estimate",
    default=default_show_iou_preds,
    action="store_false" if default_show_iou_preds else "store_true",
    help="Hide mask quality estimates" if default_show_iou_preds else "Show mask quality estimates",
)
parser.add_argument(
    "--hide_info",
    default=False,
    action="store_true",
    help="Hide text info elements from UI",
)
parser.add_argument(
    "--enable_promptless_masks",
    default=False,
    action="store_true",
    help="If set, the model will generate mask predictions even when no prompts are given",
)
parser.add_argument(
    "--crop",
    default=False,
    action="store_true",
    help="If set, a cropping UI will appear on start-up to allow for the image to be cropped prior to processing",
)

# For convenience
# -> Unpack parsed args into module-level names used throughout the rest of the script
args = parser.parse_args()
arg_image_path = args.image_path
arg_model_path = args.model_path
init_prompts_path = args.prompts_path
mask_hint_path = args.mask_path
display_size_px = args.display_size
device_str = args.device
use_float32 = args.use_float32
use_square_sizing = not args.use_aspect_ratio  # square sizing is the default; the -ar flag inverts it
imgenc_base_size = args.base_size_px
show_iou_preds = args.quality_estimate
show_info = not args.hide_info
disable_promptless_masks = not args.enable_promptless_masks
enable_crop_ui = args.crop
# Build the device/dtype settings used when moving the model to its compute device
device_config_dict = make_device_config(device_str, use_float32)

# Re-use previously selected files for any pathing not given on the command line
history = HistoryKeeper()
_, prev_image_path = history.read("image_path")
_, prev_model_path = history.read("model_path")
image_path = ask_for_path_if_missing(arg_image_path, "image", prev_image_path)
model_path = ask_for_model_path_if_missing(__file__, arg_model_path, prev_model_path)

# Record the chosen paths so they can be offered again on the next run
history.store(image_path=image_path, model_path=model_path)
# ---------------------------------------------------------------------------------------------------------------------
# %% Load resources

# Get the model name, for reporting
model_name = osp.basename(model_path)

# Load SAM weights and move the model onto the configured device/dtype
print("", "Loading model weights...", f" @ {model_path}", sep="\n", flush=True)
model_config_dict, sammodel = make_sam_from_state_dict(model_path)
sammodel.to(**device_config_dict)

# Load image and get shaping info for providing display
# -> cv2.imread returns None on failure, in which case we fall back to treating the path as a video
loaded_image_bgr = cv2.imread(image_path)
if loaded_image_bgr is None:
    ok_video, loaded_image_bgr = run_video_frame_select_ui(image_path)
    if not ok_video:
        print("", "Unable to load image!", f" @ {image_path}", sep="\n", flush=True)
        raise FileNotFoundError(osp.basename(image_path))

# Crop input image if needed (crop slice is kept for use when saving results)
input_image_bgr = loaded_image_bgr
yx_crop_slice = None
if enable_crop_ui:
    print("", "Cropping enabled: Adjust box to select image area for further processing", sep="\n", flush=True)
    _, history_crop_tlbr = history.read("crop_tlbr_norm")
    yx_crop_slice, crop_tlbr_norm = run_crop_ui(loaded_image_bgr, display_size_px, history_crop_tlbr)
    input_image_bgr = loaded_image_bgr[yx_crop_slice]
    history.store(crop_tlbr_norm=crop_tlbr_norm)

# Try loading the given mask hint
# -> Use explicit raises instead of asserts, so these checks aren't stripped when running with 'python -O'
mask_hint_img = None
if mask_hint_path is not None:
    if not osp.exists(mask_hint_path):
        raise FileNotFoundError(f"Invalid mask hint path: {mask_hint_path}")
    mask_hint_img = cv2.imread(mask_hint_path)
    if mask_hint_img is None:
        raise IOError(f"Error loading mask hint image: {mask_hint_path}")
use_mask_hint = mask_hint_img is not None
# ---------------------------------------------------------------------------------------------------------------------
# %% Run image encoder

# Run Model
print("", "Encoding image data...", sep="\n", flush=True)
t1 = perf_counter()
encoded_img, token_hw, preencode_img_hw = sammodel.encode_image(input_image_bgr, imgenc_base_size, use_square_sizing)
if torch.cuda.is_available():
    # Wait for queued GPU work to finish, so the timing below reflects the real encode time
    torch.cuda.synchronize()
t2 = perf_counter()
time_taken_ms = round(1000 * (t2 - t1))
print(f" -> Took {time_taken_ms} ms", flush=True)

# Run model without prompts as sanity check. Also gives initial result values
box_tlbr_norm_list, fg_xy_norm_list, bg_xy_norm_list = [], [], []
encoded_prompts = sammodel.encode_prompts(box_tlbr_norm_list, fg_xy_norm_list, bg_xy_norm_list)
mask_preds, iou_preds = sammodel.generate_masks(
    encoded_img, encoded_prompts, blank_promptless_output=disable_promptless_masks
)

# Set up mask hint to match image encoding, if needed
mask_hint = None
if use_mask_hint:
    # Convert hint image to single-channel & resize to match the model's mask prediction resolution
    pred_h, pred_w = mask_preds.shape[2:]
    mask_hint_img_1ch = cv2.cvtColor(mask_hint_img, cv2.COLOR_BGR2GRAY)
    mask_hint_img_1ch = cv2.resize(mask_hint_img_1ch, dsize=(pred_w, pred_h))
    mask_hint = torch.from_numpy(mask_hint_img_1ch).squeeze().unsqueeze(0)
    # Normalize to 0-to-1, then scale/shift into a +/-10 range, presumably to mimic mask logits -- TODO confirm
    # (the max(..., 1.0) guard prevents division by zero for an all-black hint image)
    mask_hint = ((mask_hint / max(mask_hint.max(), 1.0)) - 0.5) * 20.0
    mask_hint = mask_hint.to(mask_preds)
    mask_preds, iou_preds = sammodel.generate_masks(encoded_img, encoded_prompts, mask_hint)

# Provide some feedback about how the model is running
model_device = device_config_dict["device"]
model_dtype = str(device_config_dict["dtype"]).split(".")[-1]
image_hw_str = f"{preencode_img_hw[0]} x {preencode_img_hw[1]}"
token_hw_str = f"{token_hw[0]} x {token_hw[1]}"
print(
    "",
    f"Config ({model_name}):",
    f" Device: {model_device} ({model_dtype})",
    f" Resolution HW: {image_hw_str}",
    f" Tokens HW: {token_hw_str}",
    sep="\n",
    flush=True,
)

# Provide memory usage feedback, if using cuda GPU
if model_device == "cuda":
    total_vram_mb = get_total_cuda_vram_usage_mb()
    print(" VRAM:", total_vram_mb, "MB")
# ---------------------------------------------------------------------------------------------------------------------
# %% Set up the UI

# Set up shared UI elements & control logic
ui_elems = PromptUI(input_image_bgr, mask_preds)
uictrl = PromptUIControl(ui_elems)

# Set up message bars to communicate data info & controls
device_dtype_str = f"{model_device}/{model_dtype}"
header_msgbar = StaticMessageBar(model_name, f"{token_hw_str} tokens", device_dtype_str, space_equally=True)
footer_msgbar = StaticMessageBar(
    "[p] Preview",
    "[i] Invert",
    "[tab] Contouring",
    "[m] Mask hint" if use_mask_hint else "[arrows] Tools/Masks",
    text_scale=0.35,
)

# Set up secondary button controls
mask_hint_btn, show_preview_btn, invert_mask_btn, large_mask_only_btn, pick_best_btn = ToggleButton.many(
    "Mask Hint", "Preview", "Invert", "Largest Only", "Pick best", default_state=False, text_scale=0.5
)
mask_hint_btn.toggle(use_mask_hint)  # mask-hinting starts enabled only if a hint image was provided
large_mask_only_btn.toggle(True)  # default to keeping only the largest contour
save_btn = ImmediateButton("Save", (60, 170, 20))  # presumably a BGR button color -- TODO confirm
secondary_ctrls = HStack(
    mask_hint_btn if use_mask_hint else None,
    show_preview_btn,
    invert_mask_btn,
    large_mask_only_btn,
    pick_best_btn,
    save_btn,
)

# Set up slider controls
thresh_slider = HSlider("Mask Threshold", 0, -8.0, 8.0, 0.1, marker_steps=10)
rounding_slider = HSlider("Round contours", 0, -50, 50, 1, marker_steps=5)
padding_slider = HSlider("Pad contours", 0, -50, 50, 1, marker_steps=5)
simplify_slider = HSlider("Simplify contours", 0, 0, 10, 0.25, marker_steps=4)

# Set up full display layout (header/footer bars are omitted when info display is disabled)
disp_layout = VStack(
    header_msgbar if show_info else None,
    ui_elems.layout,
    secondary_ctrls,
    thresh_slider,
    simplify_slider,
    rounding_slider,
    padding_slider,
    footer_msgbar if show_info else None,
).set_debug_name("DisplayLayout")

# Render out an image with a target size, to figure out which side we should limit when rendering
display_image = disp_layout.render(h=display_size_px, w=display_size_px)
render_side = "h" if display_image.shape[1] > display_image.shape[0] else "w"
render_limit_dict = {render_side: display_size_px}
min_display_size_px = disp_layout._rdr.limits.min_h if render_side == "h" else disp_layout._rdr.limits.min_w

# Load initial prompts, if provided
have_init_prompts, init_prompts_dict = load_init_prompts(init_prompts_path)
if have_init_prompts:
    uictrl.load_initial_prompts(init_prompts_dict)
# ---------------------------------------------------------------------------------------------------------------------
# %% Window setup

# Create the main display window & hook up mouse interaction with the layout
cv2.destroyAllWindows()
window = DisplayWindow("Display - q to quit", display_fps=60).attach_mouse_callbacks(disp_layout)
window.move(200, 50)

# Arrow keys cycle through tools & mask selections
uictrl.attach_arrowkey_callbacks(window)

# Bind keypresses for the secondary toggle/button controls
keypress_bindings = {
    "p": show_preview_btn.toggle,
    KEY.TAB: large_mask_only_btn.toggle,
    "i": invert_mask_btn.toggle,
    "s": save_btn.click,
    "c": ui_elems.tools.clear.click,
}
if use_mask_hint:
    keypress_bindings["m"] = mask_hint_btn.toggle
for keycode, key_callback in keypress_bindings.items():
    window.attach_keypress_callback(keycode, key_callback)

# Named keycodes for the +/- zoom handling in the display loop
KEY_ZOOM_IN, KEY_ZOOM_OUT = ord("="), ord("-")

# Helpers that re-use display-sized copies of the base image & apply contour post-processing
base_img_maker = ReusableBaseImage(input_image_bgr)
mask_postprocessor = MaskPostProcessor()

# Print usage instructions to the terminal
print(
    "",
    "Use prompts to segment the image!",
    "- Shift-click to add multiple points",
    "- Right-click to remove points",
    "- Press -/+ keys to change display sizing",
    "- Press q or esc to close the window",
    "",
    sep="\n",
    flush=True,
)
# *** Main display loop ***
# Repeatedly: read UI state, re-run the mask decoder only when inputs change,
# post-process the selected mask into contours and redraw the display.
try:
    while True:

        # Read prompt input data & selected mask
        is_prompt_changed, (box_tlbr_norm_list, fg_xy_norm_list, bg_xy_norm_list) = uictrl.read_prompts()
        is_mask_changed, mselect_idx, selected_mask_btn = ui_elems.masks_constraint.read()

        # Read secondary controls
        is_mhint_changed, enable_mask_hint = mask_hint_btn.read()
        _, show_mask_preview = show_preview_btn.read()
        is_invert_changed, use_inverted_mask = invert_mask_btn.read()
        _, use_largest_contour = large_mask_only_btn.read()
        _, use_best_mask = pick_best_btn.read()

        # Read sliders
        is_mthresh_changed, mthresh = thresh_slider.read()
        _, msimplify = simplify_slider.read()
        _, mrounding = rounding_slider.read()
        _, mpadding = padding_slider.read()

        # Update post-processor based on control values
        mask_postprocessor.update(use_largest_contour, msimplify, mrounding, mpadding, use_inverted_mask)

        # Only run the model when an input affecting the output has changed!
        need_prompt_encode = is_prompt_changed or is_mhint_changed
        if need_prompt_encode:
            encoded_prompts = sammodel.encode_prompts(box_tlbr_norm_list, fg_xy_norm_list, bg_xy_norm_list)
            mask_preds, iou_preds = sammodel.generate_masks(
                encoded_img,
                encoded_prompts,
                mask_hint if enable_mask_hint else None,
                blank_promptless_output=disable_promptless_masks,
            )
            # Auto-select the mask with the highest predicted quality, if requested
            if use_best_mask:
                best_mask_idx = sammodel.get_best_mask_index(iou_preds)
                ui_elems.masks_constraint.change_to(best_mask_idx)

        # Update mask previews & selected mask for outlines
        need_mask_update = any((need_prompt_encode, is_mthresh_changed, is_invert_changed, is_mask_changed))
        if need_mask_update:
            selected_mask_uint8 = uictrl.create_hires_mask_uint8(mask_preds, mselect_idx, preencode_img_hw, mthresh)
            uictrl.update_mask_previews(mask_preds, mthresh, use_inverted_mask)
            if show_iou_preds:
                uictrl.draw_iou_predictions(iou_preds)

        # Process contour data
        final_mask_uint8 = selected_mask_uint8
        ok_contours, mask_contours_norm = get_contours_from_mask(final_mask_uint8, normalize=True)
        if ok_contours:

            # If only 1 fg point prompt is given, use it to hint at selecting largest masks
            point_hint = None
            only_one_fg_pt = len(fg_xy_norm_list) == 1
            no_box_prompt = len(box_tlbr_norm_list) == 0
            if only_one_fg_pt and no_box_prompt:
                point_hint = fg_xy_norm_list[0]
            mask_contours_norm, final_mask_uint8 = mask_postprocessor(final_mask_uint8, mask_contours_norm, point_hint)

        # Re-generate display image at required display size
        # -> Not strictly needed, but can avoid constant re-sizing of base image (helpful for large images)
        display_hw = ui_elems.image.get_render_hw()
        disp_img = base_img_maker.regenerate(display_hw)

        # Update the main display image in the UI
        uictrl.update_main_display_image(disp_img, final_mask_uint8, mask_contours_norm, show_mask_preview)

        # Render final output
        display_image = disp_layout.render(**render_limit_dict)
        req_break, keypress = window.show(display_image)
        if req_break:
            break

        # Scale display size up/down when pressing +/- keys (lower bound comes from the layout's render limits)
        if keypress == KEY_ZOOM_IN:
            display_size_px = min(display_size_px + 50, 10000)
            render_limit_dict = {render_side: display_size_px}
        if keypress == KEY_ZOOM_OUT:
            display_size_px = max(display_size_px - 50, min_display_size_px)
            render_limit_dict = {render_side: display_size_px}

        # Save data
        if save_btn.read():

            # Get additional data for saving
            disp_image = ui_elems.display_block.rerender()
            all_prompts_dict = make_prompt_save_data(box_tlbr_norm_list, fg_xy_norm_list, bg_xy_norm_list)

            # Generate & save segmentation images!
            save_folder, save_idx = get_save_name(image_path, "manual")
            save_image_segmentation(
                save_folder,
                save_idx,
                loaded_image_bgr,
                disp_image,
                selected_mask_uint8,
                mask_contours_norm,
                all_prompts_dict,
                use_inverted_mask,
                yx_crop_slice,
            )
            print(f"SAVED ({save_idx}):", save_folder)

        pass

except KeyboardInterrupt:
    print("", "Closed with Ctrl+C", sep="\n")

finally:
    # Always close the display window, however the loop exits
    # (removed a no-op 'except Exception as err: raise err' handler; unexpected
    # errors now propagate naturally and still trigger this cleanup)
    cv2.destroyAllWindows()