From c573187f31818bb17288d6b0278698133b737f09 Mon Sep 17 00:00:00 2001 From: LynnL4 Date: Wed, 10 May 2023 09:29:48 +0800 Subject: [PATCH] fomo: add softmax in tenosr mode --- edgelab/models/detectors/fomo.py | 22 ++++++++++++++++++++++ tools/torch2tflite.py | 7 +++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/edgelab/models/detectors/fomo.py b/edgelab/models/detectors/fomo.py index 236abe7e..de548fcd 100644 --- a/edgelab/models/detectors/fomo.py +++ b/edgelab/models/detectors/fomo.py @@ -1,3 +1,4 @@ +import torch from typing import Optional, Dict from mmdet.models.detectors.single_stage import SingleStageDetector from edgelab.registry import MODELS @@ -17,3 +18,24 @@ def __init__(self, init_cfg: Optional[Dict] = None): super().__init__(backbone, neck, head, train_cfg, test_cfg, data_preprocessor, init_cfg) + + + def _forward( + self, + batch_inputs, + batch_data_samples): + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple[list]: A tuple of features from ``bbox_head`` forward. + """ + x = self.extract_feat(batch_inputs) + results = self.bbox_head.forward(x) + return torch.softmax(results[0], dim=1) \ No newline at end of file diff --git a/tools/torch2tflite.py b/tools/torch2tflite.py index 855830b8..45b654fc 100644 --- a/tools/torch2tflite.py +++ b/tools/torch2tflite.py @@ -5,6 +5,7 @@ import torch import numpy as np from copy import deepcopy +from functools import partial import edgelab.models import edgelab.datasets import edgelab.evaluation @@ -159,6 +160,10 @@ def export_tflite(args, model, context: DLContext): model: The model to be exported context (DLContext): The dataset context object """ + model.cpu().eval() + + #model.forward = partial(model.forward, mode='tensor') + dummy_input = torch.randn(1, *args.shape) if args.type == 'int8' or type == 'uint8': with model_tracer(): @@ -184,8 +189,6 @@ def export_tflite(args, model, context: DLContext): else: with torch.no_grad(): - model.cpu() - model.eval() torch.backends.quantized.engine = 'qnnpack' converter = TFLiteConverter( model, dummy_input, tflite_path=args.tflite_file)