-
Notifications
You must be signed in to change notification settings - Fork 0
/
FastSAM_inference.py
86 lines (74 loc) · 2.75 KB
/
FastSAM_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import sys

import torch
from PIL import Image

sys.path.insert(0,"/home/skn/arpl/Projects/MOTfastSAM/FastSAM")
from fastsam import FastSAM, FastSAMPrompt
from utils.tools import convert_box_xywh_to_xyxy
class get_fSAM():
    """Wrapper that runs FastSAM on one image/frame per ``infer`` call.

    Prompt settings are plain instance attributes set in ``__init__`` and
    re-read on every ``infer`` call, so callers may change them between
    frames. Each call increments ``output_frame`` and hands
    ``<output>/<output_frame>.jpg`` to ``FastSAMPrompt.plot`` as its
    output path (presumably where the rendered result is saved).
    """

    def __init__(self):
        # FastSAM Arguments
        self.model_path = "FastSAM/weights/FastSAM.pt"
        self.output = "resources/output/"
        # NOTE(review): imgsz is not currently passed to the model in infer().
        self.imgsz = 800
        self.iou = 0.9  # iou threshold for filtering the annotations
        self.text_prompt = None  # e.g. "human"
        self.conf = 0.4  # object confidence threshold
        self.randomcolor = True
        # The all-zero defaults below disable point/box prompting in infer().
        self.point_prompt = [[0, 0]]
        self.point_label = [0]
        self.box_prompt = [[0, 0, 0, 0]]  # xywh; converted to xyxy in infer()
        self.better_quality = False
        self.retina = True
        self.withContours = False
        self.output_frame = 0  # incremented once per infer() call
        self.device = self._select_device()
        print("device: ", self.device)
        # Lazily populated model cache; see _load_model().
        self._model = None
        self._model_loaded_from = None

    @staticmethod
    def _select_device():
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps may be absent on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _load_model(self):
        """Return a FastSAM model for ``self.model_path``, loading it once per path."""
        # getattr keeps this working on instances created without __init__.
        if (getattr(self, "_model", None) is None
                or getattr(self, "_model_loaded_from", None) != self.model_path):
            self._model = FastSAM(self.model_path)
            self._model_loaded_from = self.model_path
        return self._model

    def infer(self, array, imsz):
        """Segment ``array`` with FastSAM and plot the selected annotations.

        The prompt mode is chosen by priority from the instance settings:
        box prompt (if the converted box has non-zero ``[2]`` and ``[3]``),
        then text prompt (if not None), then point prompt (if the first
        point is not ``[0, 0]``), otherwise "everything" mode.

        Args:
            array: image forwarded unchanged to the FastSAM model and to
                ``FastSAMPrompt`` (type presumably numpy/PIL — confirm with
                the caller).
            imsz: currently unused; kept for interface compatibility.
                NOTE(review): neither this nor ``self.imgsz`` reaches the model.

        Returns:
            Whatever ``FastSAMPrompt.plot`` returns.
        """
        model = self._load_model()
        box_prompt = convert_box_xywh_to_xyxy(self.box_prompt)
        everything_results = model(
            array,
            device=self.device,
            retina_masks=self.retina,
            conf=self.conf,
            iou=self.iou,
        )
        # Only the prompt that was actually applied is forwarded to plot().
        bboxes = None
        points = None
        point_label = None
        prompt_process = FastSAMPrompt(array, everything_results, device=self.device)
        self.output_frame += 1
        if box_prompt[0][2] != 0 and box_prompt[0][3] != 0:
            print("Box Prompt is ", self.box_prompt)
            ann = prompt_process.box_prompt(bboxes=box_prompt)
            bboxes = box_prompt
        elif self.text_prompt is not None:
            print("Text Prompt is ", self.text_prompt)
            ann = prompt_process.text_prompt(text=self.text_prompt)
        elif self.point_prompt[0] != [0, 0]:
            print("Point Prompt is ", self.point_prompt)
            # Use the configured labels (previously clobbered to None).
            ann = prompt_process.point_prompt(
                points=self.point_prompt, pointlabel=self.point_label
            )
            points = self.point_prompt
            point_label = self.point_label
        else:
            ann = prompt_process.everything_prompt()
        return prompt_process.plot(
            annotations=ann,
            output_path=os.path.join(self.output, f"{self.output_frame}.jpg"),
            bboxes=bboxes,
            points=points,
            point_label=point_label,
            withContours=self.withContours,
            better_quality=self.better_quality,
        )