Skip to content

Commit

Permalink
Add an option to process multiple points and loop them as single point
Browse files Browse the repository at this point in the history
  • Loading branch information
Rub21 committed Nov 11, 2023
1 parent 3a7ee45 commit ac9b577
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions handler_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ def initialize(self, context):
self.initialized = True
print("XXXXX Initialization time: ", time()-start)

def inference_single_point(self, image_embeddings):
def inference_single_point(self, image_embeddings, input_point, input_label):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
start = time()
resizer = ResizeLongestSide(1024) # 1024 is the max for the export onnx example nb
onnx_coord = np.concatenate([np.array(self.payload['input_point'])[np.newaxis,:], np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([np.array([self.payload['input_label']]), np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = np.concatenate([np.array(input_point)[np.newaxis,:], np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([np.array(input_label), np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = resizer.apply_coords(onnx_coord, self.payload.get("image_shape")).astype(np.float32)
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
Expand Down Expand Up @@ -183,16 +183,38 @@ def handle(self, data, context):
:return: prediction output
"""
model_input = self.preprocess(data)
masks = None
scores = None
# (N,512,512) mask input for single point, ambiguous proposals, highest conf not always best
# Check decode type, which comes on the payload
########### Single point
if self.payload["decode_type"] == "single_point":
masks, scores = self.inference_single_point(model_input)
input_point = self.payload['input_point']
input_label = [self.payload['input_label']]
masks, scores = self.inference_single_point(model_input, input_point, input_label)
########### Multi point
elif self.payload["decode_type"] == "multi_point":
masks, scores = self.inference_multi_point(model_input)
########### Multi point, split in a single point
elif self.payload["decode_type"] == "multi_point_split":
all_scores = None
all_masks = None
for index,input_point_ in enumerate(self.payload['input_point']):
input_label_ = [self.payload['input_label'][index]]
if index == 0:
all_masks, all_scores = self.inference_single_point(model_input,input_point_, input_label_)
else:
masks, scores = self.inference_single_point(model_input,input_point_, input_label_)
all_masks = np.concatenate((all_masks, masks), axis=0)
all_scores = np.concatenate((all_scores, scores), axis=0)
masks = all_masks
scores = all_scores

# Convert to geojson
if self.payload.get("crs") is not None and self.payload.get("bbox") is not None:
geojsons = []
for mask in masks:# need to clean this up and apply conversion to each ambiguous mask
split_masks = np.array_split(masks, masks.shape[0], axis=0)
for mask in split_masks: # need to clean this up and apply conversion to each ambiguous mask
print(f"xxxxxxxxxShape mask: {mask.shape}")
multipolygon = self.mask_to_geojson(mask, scores)
geojsons.append(multipolygon)
Expand Down

0 comments on commit ac9b577

Please sign in to comment.