Skip to content

Commit

Permalink
fomo: add softmax in tenosr mode
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnL4 committed May 10, 2023
1 parent ed16cb4 commit c573187
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
22 changes: 22 additions & 0 deletions edgelab/models/detectors/fomo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from typing import Optional, Dict
from mmdet.models.detectors.single_stage import SingleStageDetector
from edgelab.registry import MODELS
Expand All @@ -17,3 +18,24 @@ def __init__(self,
init_cfg: Optional[Dict] = None):
super().__init__(backbone, neck, head, train_cfg, test_cfg,
data_preprocessor, init_cfg)


def _forward(
self,
batch_inputs,
batch_data_samples):
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
the meta information of each image and corresponding
annotations.
Returns:
tuple[list]: A tuple of features from ``bbox_head`` forward.
"""
x = self.extract_feat(batch_inputs)
results = self.bbox_head.forward(x)
return torch.softmax(results[0], dim=1)
7 changes: 5 additions & 2 deletions tools/torch2tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import numpy as np
from copy import deepcopy
from functools import partial
import edgelab.models
import edgelab.datasets
import edgelab.evaluation
Expand Down Expand Up @@ -159,6 +160,10 @@ def export_tflite(args, model, context: DLContext):
model: The model to be exported
context (DLContext): The dataset context object
"""
model.cpu().eval()

#model.forward = partial(model.forward, mode='tensor')

dummy_input = torch.randn(1, *args.shape)
if args.type == 'int8' or type == 'uint8':
with model_tracer():
Expand All @@ -184,8 +189,6 @@ def export_tflite(args, model, context: DLContext):

else:
with torch.no_grad():
model.cpu()
model.eval()
torch.backends.quantized.engine = 'qnnpack'
converter = TFLiteConverter(
model, dummy_input, tflite_path=args.tflite_file)
Expand Down

0 comments on commit c573187

Please sign in to comment.