Skip to content

Commit

Permalink
Add multipoint function
Browse files Browse the repository at this point in the history
  • Loading branch information
Rub21 committed Nov 9, 2023
1 parent 87f0fa7 commit 4285845
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion handler_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def preprocess(self, data):
"""
start = time()
row = data[0]

self.payload = row.get("data") or row.get("body")
if isinstance(self.payload, dict) and "image_embeddings" in self.payload \
and "image_shape" in self.payload \
Expand Down Expand Up @@ -102,6 +103,44 @@ def inference_single_point(self, image_embeddings):
print("XXXXXXX ambiguous mask proposals shape for single point", masks.shape) #(512,512) for single point mask
return masks[0], scores[0]

def inference_multi_point(self, image_embeddings):
"""
Internal inference methods for multipoints
:param predictor: transformed model input data
:return: list of inference output in NDArray
"""
start = time()
resizer = ResizeLongestSide(1024)
payload_input_point=self.payload['input_point'] # e.g: [[500, 375], [1125, 625]]
payload_input_label=self.payload['input_label'] # e.g: [1, 1]
payload_image_size=self.payload.get("image_shape")

input_point = np.array(payload_input_point)
input_label = np.array(payload_input_label)
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
# onnx_coord = image_embeddings.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
onnx_coord = resizer.apply_coords(onnx_coord, payload_image_size).astype(np.float32)
onnx_has_mask_input = np.ones(1, dtype=np.float32)

ort_inputs = {
"image_embeddings": image_embeddings,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(payload_image_size, dtype=np.float32)
}

masks, scores, low_res_logits = self.ort_session.run(None, ort_inputs)
masks = masks > .5 # TODO make this configurable
print("XXXXX Inference time: ", time()-start)
print("XXXXXXX ambiguous mask proposals shape for multi point", masks.shape) #(512,512) for single point mask
return masks[0], scores[0]


def mask_to_geojson(self, mask, scores):
transform = rasterio.transform.from_bounds(*self.payload.get("bbox"), mask.shape[1], mask.shape[0])
# A list to store all features
Expand Down Expand Up @@ -144,7 +183,19 @@ def handle(self, data, context):
"""
model_input = self.preprocess(data)
# (N,512,512) mask input for single point, ambiguous proposals, highest conf not always best
masks, scores = self.inference_single_point(model_input)
payload_input_point = self.payload['input_point']
# payload_input_label = self.payload['input_label']
# # Check multipoint or single point request
# multipoint=[[771, 381],[771, 450]
# singlepoint=[771, 381]
if isinstance(payload_input_point[0], list):
print("#"*20 + " Multi Point " + "#"*20 )
masks, scores = self.inference_multi_point(model_input)
else:
print("#"*20 + " Single Point " + "#"*20 )
masks, scores = self.inference_single_point(model_input)

# Convert to geojson
if self.payload.get("crs") is not None and self.payload.get("bbox") is not None:
geojsons = []
for mask in masks:# need to clean this up and apply conversion to each ambiguous mask
Expand All @@ -156,6 +207,7 @@ def handle(self, data, context):
masks = self.postprocess(masks)
return [{"status": "success", "masks": masks, "confidence_scores": [np_to_py_type(score) for score in scores]}]


def np_to_py_type(o):
if isinstance(o, np.generic):
return o.item()
Expand Down

0 comments on commit 4285845

Please sign in to comment.