From af20f5392c1da18fabdbfff58b1d90e8bcadec0d Mon Sep 17 00:00:00 2001 From: CoreyEWood Date: Tue, 8 Oct 2024 11:34:31 -0700 Subject: [PATCH] Fix count result parsing (#105) The format of roi response is a bit different than what was expected. --------- Co-authored-by: Auto-format Bot --- app/core/edge_inference.py | 6 +- .../test_parse_inference_response.py | 72 ++++++++++--------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/app/core/edge_inference.py b/app/core/edge_inference.py index 963086f1..d8321004 100644 --- a/app/core/edge_inference.py +++ b/app/core/edge_inference.py @@ -85,12 +85,12 @@ def parse_inference_response(response: dict) -> dict: text: str | None = None # Attempt to extract rois / text if secondary_predictions is not None: - roi_predictions: list[list[dict]] | None = secondary_predictions.get("roi_predictions", None) + roi_predictions: dict[str, list[list[dict]]] | None = secondary_predictions.get("roi_predictions", None) text_predictions: list[str] | None = secondary_predictions.get("text_predictions", None) if roi_predictions is not None: - rois = roi_predictions[0] + rois = roi_predictions["rois"][0] for i, roi in enumerate(rois): - geometry = rois[i]["geometry"] + geometry = roi["geometry"] # TODO add validation to calculate x and y automatically x = 0.5 * (geometry["left"] + geometry["right"]) y = 0.5 * (geometry["top"] + geometry["bottom"]) diff --git a/test/edge_inference/test_parse_inference_response.py b/test/edge_inference/test_parse_inference_response.py index 1f1109af..13b420cb 100644 --- a/test/edge_inference/test_parse_inference_response.py +++ b/test/edge_inference/test_parse_inference_response.py @@ -22,12 +22,14 @@ def mock_count_response(): }, "predictions": None, "secondary_predictions": { - "roi_predictions": [[{ - "label": "bird", - "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, - "score": 0.9, - "version": "2.0", - }]], + "roi_predictions": { + "rois": [[{ + "label": "bird", + "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, + "score": 0.9, + "version": "2.0", + }]] + }, "text_predictions": ["This is a bird."], }, } @@ -39,20 +41,22 @@ def mock_binary_with_rois_response(): "multi_predictions": None, "predictions": {"confidences": [0.54], "labels": [0], "probabilities": [0.45], "scores": [-2.94]}, "secondary_predictions": { - "roi_predictions": [[ - { - "label": "cat", - "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, - "score": 0.8, - "version": "2.0", - }, - { - "label": "cat", - "geometry": {"left": 0.6, "top": 0.7, "right": 0.8, "bottom": 0.9, "version": "2.0"}, - "score": 0.7, - "version": "2.0", - }, - ]], + "roi_predictions": { + "rois": [[ + { + "label": "cat", + "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, + "score": 0.8, + "version": "2.0", + }, + { + "label": "cat", + "geometry": {"left": 0.6, "top": 0.7, "right": 0.8, "bottom": 0.9, "version": "2.0"}, + "score": 0.7, + "version": "2.0", + }, + ]] + }, "text_predictions": None, }, } @@ -79,12 +83,14 @@ def mock_invalid_predictions_response(): }, "predictions": {"confidences": [0.54], "labels": [0], "probabilities": [0.45], "scores": [-2.94]}, "secondary_predictions": { - "roi_predictions": [[{ - "label": "bird", - "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, - "score": 0.9, - "version": "2.0", - }]], + "roi_predictions": { + "rois": [[{ + "label": "bird", + "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, + "score": 0.9, + "version": "2.0", + }]] + }, "text_predictions": None, }, } @@ -96,12 +102,14 @@ def mock_invalid_predictions_missing_response(): "multi_predictions": None, "predictions": None, "secondary_predictions": { - "roi_predictions": [[{ - "label": "bird", - "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, - "score": 0.9, - "version": "2.0", - }]], + "roi_predictions": { + "rois": [[{ + "label": "bird", + "geometry": {"left": 0.1, "top": 0.2, "right": 0.3, "bottom": 0.4, "version": "2.0"}, + "score": 0.9, + "version": "2.0", + }]] + }, "text_predictions": None, }, }