Skip to content

Commit

Permalink
#368 Update inference interface (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Aug 29, 2024
1 parent 3fe7f64 commit 68a9ae5
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 55 deletions.
6 changes: 3 additions & 3 deletions examples/inferencer/custom_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def postprocessor():

netspresso = NetsPresso(email=EMAIL, password=PASSWORD)

inferencer = netspresso.custom_inferencer()
input_model_path = "YOUR_MODEL_PATH"
inferencer = netspresso.custom_inferencer(input_model_path=input_model_path)

image_path = "YOUR_IMAGE_PATH"
colored_img = cv2.imread(image_path)
Expand All @@ -31,8 +32,7 @@ def postprocessor():
np.save(dataset_path, img)

# Inference
input_model_path = "YOUR_MODEL_PATH"
outputs = inferencer.inference(input_model_path, dataset_path)
outputs = inferencer.inference(dataset_path)

# Postprocess
pred = postprocessor()
6 changes: 3 additions & 3 deletions examples/inferencer/np_inferencer/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@

# 1. Declare inferencer
config_path = trained_result.runtime
inferencer = netspresso.np_inferencer(config_path=config_path)
input_model_path = trained_result.best_onnx_model_path
inferencer = netspresso.np_inferencer(config_path=config_path, input_model_path=input_model_path)

# 2. Inference image
valid_imgs = glob("/root/datasets/cifar100/images/valid/*.png")
for valid_img in valid_imgs[:100]:
input_model_path = trained_result.best_onnx_model_path
save_path = f"{Path(input_model_path).parent}/inference_results/{Path(valid_img).name}"
outputs = inferencer.inference(input_model_path, valid_img, save_path)
outputs = inferencer.inference(image_path=valid_img, save_path=save_path)
6 changes: 3 additions & 3 deletions examples/inferencer/np_inferencer/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@

# 1. Declare inferencer
config_path = trained_result.runtime
inferencer = netspresso.np_inferencer(config_path=config_path)
input_model_path = trained_result.best_onnx_model_path
inferencer = netspresso.np_inferencer(config_path=config_path, input_model_path=input_model_path)

# 2. Inference image
valid_imgs = glob("/root/datasets/traffic-sign/images/valid/*.jpg")
for valid_img in valid_imgs:
input_model_path = trained_result.best_onnx_model_path
save_path = f"{Path(input_model_path).parent}/inference_results/{Path(valid_img).name}"
outputs = inferencer.inference(input_model_path, valid_img, save_path)
outputs = inferencer.inference(image_path=valid_img, save_path=save_path)
6 changes: 3 additions & 3 deletions examples/inferencer/np_inferencer/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@

# 1. Declare inferencer
config_path = trained_result.runtime
inferencer = netspresso.np_inferencer(config_path=config_path)
input_model_path = trained_result.best_onnx_model_path
inferencer = netspresso.np_inferencer(config_path=config_path, input_model_path=input_model_path)

# 2. Inference image
valid_imgs = glob("/root/datasets/voc2012_seg/images/valid/*.jpg")
for valid_img in valid_imgs[:100]:
input_model_path = trained_result.best_onnx_model_path
save_path = Path(input_model_path).parent / "inference_results" / Path(valid_img).name
outputs = inferencer.inference(input_model_path, valid_img, save_path)
outputs = inferencer.inference(image_path=valid_img, save_path=save_path)
72 changes: 33 additions & 39 deletions netspresso/inferencer/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,35 @@


class BaseInferencer:
def __init__(self) -> None:
pass
def __init__(self, input_model_path) -> None:
self.input_model_path = input_model_path
self.inferencer = self._create_inferencer(input_model_path)

self.suffix = Path(input_model_path).suffix
self.runtime = Runtime.get_runtime_by_suffix(self.suffix)

def _inference(self, dataset_path: str):
inference_results = self.inferencer.inference(dataset_path)

return inference_results

def _create_inferencer(self, input_model_path: str):
self.inferencer = InferenceService(model_file_path=input_model_path)
inferencer = InferenceService(model_file_path=input_model_path)

return inferencer

def transpose_input(self, runtime, input):
if runtime == Runtime.ONNX:
def transpose_input(self, input):
if self.runtime == Runtime.ONNX:
input = input.transpose(0, 3, 1, 2)
elif runtime == Runtime.TFLITE:
elif self.runtime == Runtime.TFLITE:
pass

return input

def transpose_outputs(self, runtime, outputs):
if runtime == Runtime.ONNX:
def transpose_outputs(self, outputs):
if self.runtime == Runtime.ONNX:
pass
elif runtime == Runtime.TFLITE:
elif self.runtime == Runtime.TFLITE:
outputs = [np.transpose(index, (0, 3, 1, 2)) for index in outputs] # (b, h, w, c) -> (b, c, h, w)

return outputs
Expand All @@ -68,8 +74,8 @@ def save_image(self, image, save_path):


class NPInferencer(BaseInferencer):
def __init__(self, config_path) -> None:
super().__init__()
def __init__(self, config_path: str, input_model_path: str) -> None:
super().__init__(input_model_path)
self.config_path = config_path
self.runtime_config = OmegaConf.load(config_path).runtime
self.build_preprocessor()
Expand Down Expand Up @@ -124,48 +130,42 @@ def dequantize_outputs(self, results):

return results

def preprocess_input(self, runtime: Runtime, inputs):
if runtime == Runtime.ONNX:
input_data = self.transpose_input(runtime=runtime, input=inputs)
elif runtime == Runtime.TFLITE:
def preprocess_input(self, inputs):
if self.runtime == Runtime.ONNX:
input_data = self.transpose_input(input=inputs)
elif self.runtime == Runtime.TFLITE:
input_data = self.quantize_input(inputs)

return input_data

def postprocess_output(self, runtime: Runtime, outputs):
if runtime == Runtime.ONNX:
def postprocess_output(self, outputs):
if self.runtime == Runtime.ONNX:
pass
elif runtime == Runtime.TFLITE:
elif self.runtime == Runtime.TFLITE:
outputs = self.dequantize_outputs(outputs)

outputs = list(outputs.values())
outputs = self.transpose_outputs(runtime, outputs)
outputs = self.transpose_outputs(outputs)

return outputs

def inference(self, input_model_path: str, image_path: str, save_path: str):
suffix = Path(input_model_path).suffix
runtime = Runtime.get_runtime_by_suffix(suffix)

# Create inferencer
self._create_inferencer(input_model_path)

def inference(self, image_path: str, save_path: str):
# Load image
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_draw = img.copy()

# Preprocess image
img = self.preprocessor(img)
input_data = self.preprocess_input(runtime=runtime, inputs=img)
input_data = self.preprocess_input(inputs=img)
dataset_path = self.save_numpy_data(data=input_data)

# Inference data
inference_results = self._inference(dataset_path)
Path(dataset_path).unlink()

# Postprocess outputs
outputs = self.postprocess_output(runtime, inference_results)
outputs = self.postprocess_output(outputs=inference_results)

model_input_shape = None

Expand All @@ -187,16 +187,10 @@ def inference(self, input_model_path: str, image_path: str, save_path: str):


class CustomInferencer(BaseInferencer):
def __init__(self) -> None:
super().__init__()

def inference(self, input_model_path: str, dataset_path: str):
suffix = Path(input_model_path).suffix
runtime = Runtime.get_runtime_by_suffix(suffix)
def __init__(self, input_model_path: str) -> None:
super().__init__(input_model_path)

inference_results = self._inference(input_model_path, dataset_path)

outputs = list(inference_results.values())
outputs = self.transpose_outputs(runtime, outputs)
def inference(self, dataset_path: str):
inference_results = self._inference(dataset_path)

return outputs
return inference_results
8 changes: 4 additions & 4 deletions netspresso/netspresso.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,22 @@ def benchmarker_v2(self) -> BenchmarkerV2:
"""
return BenchmarkerV2(token_handler=self.token_handler, user_info=self.user_info)

def np_inferencer(self, config_path) -> NPInferencer:
def np_inferencer(self, config_path: str, input_model_path: str) -> NPInferencer:
"""Initialize and return a Inferencer instance.
Returns:
Inferencer: Initialized Inferencer instance.
"""

return NPInferencer(config_path=config_path)
return NPInferencer(config_path=config_path, input_model_path=input_model_path)

def custom_inferencer(self) -> CustomInferencer:
def custom_inferencer(self, input_model_path: str) -> CustomInferencer:
"""Initialize and return a Inferencer instance.
Returns:
Inferencer: Initialized Inferencer instance.
"""
return CustomInferencer()
return CustomInferencer(input_model_path=input_model_path)


class TAO:
Expand Down

0 comments on commit 68a9ae5

Please sign in to comment.