From b5c6d0eed2a384a715cc800548d06458daa2e774 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Sun, 18 Feb 2024 18:06:41 +0800 Subject: [PATCH] Refactor 2024 (#128) * WIP * ssd done * test faster rcnn * update readme * update readme * fix int8 * fix mask * update readme --- .pre-commit-config.yaml | 2 +- README.md | 116 +++--- demo/cpp/README.md | 9 +- demo/inference.py | 25 +- docker/Dockerfile | 20 +- docs/FAQ.md | 18 +- docs/getting_started.md | 68 ++-- mmdet2trt/__init__.py | 2 +- mmdet2trt/__main__.py | 139 ++++++++ mmdet2trt/apis/inference.py | 202 +++++------ mmdet2trt/converters/SAConv2d.py | 5 +- mmdet2trt/converters/__init__.py | 44 +-- mmdet2trt/converters/anchor_generator.py | 20 -- mmdet2trt/converters/bfp_forward.py | 3 +- mmdet2trt/converters/delta2bbox_custom.py | 4 +- mmdet2trt/converters/generalized_attention.py | 3 +- mmdet2trt/core/__init__.py | 2 - mmdet2trt/core/bbox/__init__.py | 5 - .../core/bbox/iou_calculators/__init__.py | 3 - mmdet2trt/mmdet2trt.py | 330 ++++++------------ mmdet2trt/models/__init__.py | 17 +- mmdet2trt/models/backbones/base_backbone.py | 21 +- mmdet2trt/models/builder.py | 7 +- .../models/dense_heads/anchor_free_head.py | 3 +- mmdet2trt/models/dense_heads/anchor_head.py | 25 +- mmdet2trt/models/dense_heads/atss_head.py | 9 +- .../models/dense_heads/cascade_rpn_head.py | 13 +- .../models/dense_heads/centripetal_head.py | 5 +- mmdet2trt/models/dense_heads/corner_head.py | 7 +- mmdet2trt/models/dense_heads/detr_head.py | 9 +- mmdet2trt/models/dense_heads/fcos_head.py | 9 +- mmdet2trt/models/dense_heads/fovea_head.py | 7 +- mmdet2trt/models/dense_heads/ga_rpn_head.py | 7 +- mmdet2trt/models/dense_heads/gfl_head.py | 13 +- .../models/dense_heads/guided_anchor_head.py | 15 +- mmdet2trt/models/dense_heads/paa_head.py | 11 +- .../models/dense_heads/reppoints_head.py | 13 +- mmdet2trt/models/dense_heads/rpn_head.py | 15 +- .../models/dense_heads/sabl_retina_head.py | 13 +- mmdet2trt/models/dense_heads/vfnet_head.py | 9 +- mmdet2trt/models/dense_heads/yolo_head.py | 15 +- mmdet2trt/models/dense_heads/yolox_head.py | 9 +- mmdet2trt/models/detectors/single_stage.py | 47 ++- mmdet2trt/models/detectors/two_stage.py | 28 +- .../models/{utils => layers}/__init__.py | 0 .../{utils => layers}/position_encoding.py | 5 +- mmdet2trt/models/necks/base_neck.py | 19 +- mmdet2trt/models/necks/hrfpn.py | 5 +- .../models/roi_heads/bbox_heads/bbox_head.py | 13 +- .../roi_heads/bbox_heads/double_bbox_head.py | 4 +- .../models/roi_heads/bbox_heads/sabl_head.py | 5 +- .../models/roi_heads/cascade_roi_head.py | 20 +- mmdet2trt/models/roi_heads/double_roi_head.py | 4 +- mmdet2trt/models/roi_heads/grid_roi_head.py | 13 +- mmdet2trt/models/roi_heads/htc_roi_head.py | 11 +- .../roi_heads/mask_heads/fcn_mask_head.py | 6 +- .../models/roi_heads/mask_heads/grid_head.py | 5 +- .../roi_heads/mask_heads/htc_mask_head.py | 6 +- .../roi_extractors/generic_roi_extractor.py | 7 +- .../deform_roi_pool_extractor.py | 10 +- .../single_level_roi_extractor.py | 9 +- .../models/roi_heads/standard_roi_head.py | 20 +- mmdet2trt/models/task_modules/__init__.py | 2 + .../task_modules/coders}/__init__.py | 0 .../coders}/bucketing_bbox_coder.py | 5 +- .../coders}/delta_xywh_bbox_coder.py | 5 +- .../task_modules/coders}/tblr_bbox_coder.py | 7 +- .../task_modules/coders}/transforms.py | 0 .../task_modules/coders}/yolo_bbox_coder.py | 5 +- .../prior_generators}/__init__.py | 0 .../prior_generators}/anchor_generator.py | 12 +- .../prior_generators}/point_generator.py | 10 +- mmdet2trt/structures/__init__.py | 0 mmdet2trt/structures/bbox/__init__.py | 7 + .../bbox/bbox_overlaps.py} | 0 .../{core => structures}/bbox/transforms.py | 3 +- setup.py | 2 +- tests/model_test.py | 29 +- tools/collect_env.py | 1 - tools/test.py | 6 +- 80 files changed, 780 insertions(+), 823 deletions(-) create mode 100644 mmdet2trt/__main__.py delete mode 100644 mmdet2trt/core/bbox/__init__.py delete mode 100644 mmdet2trt/core/bbox/iou_calculators/__init__.py rename mmdet2trt/models/{utils => layers}/__init__.py (100%) rename mmdet2trt/models/{utils => layers}/position_encoding.py (92%) create mode 100644 mmdet2trt/models/task_modules/__init__.py rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/__init__.py (100%) rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/bucketing_bbox_coder.py (96%) rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/delta_xywh_bbox_coder.py (97%) rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/tblr_bbox_coder.py (95%) rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/transforms.py (100%) rename mmdet2trt/{core/bbox/coder => models/task_modules/coders}/yolo_bbox_coder.py (90%) rename mmdet2trt/{core/anchor => models/task_modules/prior_generators}/__init__.py (100%) rename mmdet2trt/{core/anchor => models/task_modules/prior_generators}/anchor_generator.py (88%) rename mmdet2trt/{core/anchor => models/task_modules/prior_generators}/point_generator.py (91%) create mode 100644 mmdet2trt/structures/__init__.py create mode 100644 mmdet2trt/structures/bbox/__init__.py rename mmdet2trt/{core/bbox/iou_calculators/iou2d_calculator.py => structures/bbox/bbox_overlaps.py} (100%) rename mmdet2trt/{core => structures}/bbox/transforms.py (99%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e43b31..dd81a63 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: flake8 - repo: https://github.com/timothycrosley/isort - rev: 5.10.1 + rev: 4.3.21 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-yapf diff --git a/README.md b/README.md index 7f38d03..f256ccc 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,39 @@ # MMDet to TensorRT +> [!NOTE] +> +> The main branch is used to support model conversion of MMDetection>=3.0. +> If you want to convert model on older MMDetection, Please switch to branch: +> - [mmdet2trt=v0.5.0](https://github.com/grimoire/mmdetection-to-tensorrt/tree/v0.5.0) +> - [torch2trt_dynamic=v0.5.0](https://github.com/grimoire/torch2trt_dynamic/tree/v0.5.0) +> - [amirstan_plugin=v0.5.0](https://github.com/grimoire/amirstan_plugin/tree/v0.5.0). + ## News -OpenMMLab has release [MMDeploy](https://github.com/open-mmlab/mmdeploy) which support more inference engine and repos. PRs and advices are welcome ! +- 2024.02: Support MMDetection>=3.0 ## Introduction -This project aims to convert the mmdetection model to TensorRT model end2end. -Focus on object detection for now. +This project aims to support End2End deployment of models in MMDetection with TensorRT. + Mask support is **experiment**. -support: +Features: - fp16 - int8(experiment) - batched input - dynamic input shape - combination of different modules -- deepstream support - -Any advices, bug reports and stars are welcome. - -## License - -This project is released under the [Apache 2.0 license](LICENSE). +- DeepStream ## Requirement -- install mmdetection: +- install MMDetection: ```bash - # mim is so cool! pip install openmim - mim install mmdet==2.14.0 + mim install mmdet==3.3.0 ``` - install [torch2trt_dynamic](https://github.com/grimoire/torch2trt_dynamic): @@ -40,7 +41,7 @@ This project is released under the [Apache 2.0 license](LICENSE). ```bash git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic cd torch2trt_dynamic - python setup.py develop + pip install -e . ``` - install [amirstan_plugin](https://github.com/grimoire/amirstan_plugin): @@ -57,11 +58,13 @@ This project is released under the [Apache 2.0 license](LICENSE). make -j10 ``` - - **DON'T FORGET** setting the envoirment variable(in `~/.bashrc`): - - ```bash - export AMIRSTAN_LIBRARY_PATH=${amirstan_plugin_root}/build/lib - ``` + > [!NOTE] + > + > **DON'T FORGET** setting the environment variable(in `~/.bashrc`): + > + > ```bash + > export AMIRSTAN_LIBRARY_PATH=${amirstan_plugin_root}/build/lib + > ``` ## Installation @@ -70,7 +73,7 @@ This project is released under the [Apache 2.0 license](LICENSE). ```bash git clone https://github.com/grimoire/mmdetection-to-tensorrt.git cd mmdetection-to-tensorrt -python setup.py develop +pip install -e . ``` ### Docker @@ -78,17 +81,9 @@ python setup.py develop Build docker image ```bash -# cuda11.1 TensorRT7.2.2 pytorch1.8 cuda11.1 sudo docker build -t mmdet2trt_docker:v1.0 docker/ ``` -You can also specify CUDA, Pytorch and Torchvision versions with docker build args by: - -```bash -# cuda11.1 tensorrt7.2.2 pytorch1.6 cuda10.2 -sudo docker build -t mmdet2trt_docker:v1.0 --build-arg TORCH_VERSION=1.6.0 --build-arg TORCHVISION_VERSION=0.7.0 --build-arg CUDA=10.2 --docker/ -``` - Run (will show the help for the CLI entrypoint) ```bash @@ -109,12 +104,13 @@ sudo docker run --gpus all -it --rm -v ${your_data_path}:${bind_path} mmdet2trt_ ## Usage -how to create a TensorRT model from mmdet model (converting might take few minutes)(Might have some warning when converting.) +Create a TensorRT model from mmdet model. detail can be found in [getting_started.md](./docs/getting_started.md) ### CLI ```bash +# conversion might take few minutes. mmdet2trt ${CONFIG_PATH} ${CHECKPOINT_PATH} ${OUTPUT_PATH} ``` @@ -123,15 +119,17 @@ Run mmdet2trt -h for help on optional arguments. ### Python ```python -opt_shape_param=[ - [ - [1,3,320,320], # min shape - [1,3,800,1344], # optimize shape - [1,3,1344,1344], # max shape - ] -] -max_workspace_size=1<<30 # some module and tactic need large workspace. -trt_model = mmdet2trt(cfg_path, weight_path, opt_shape_param=opt_shape_param, fp16_mode=True, max_workspace_size=max_workspace_size) +shape_ranges=dict( + x=dict( + min=[1,3,320,320], + opt=[1,3,800,1344], + max=[1,3,1344,1344], + ) +) +trt_model = mmdet2trt(cfg_path, + weight_path, + shape_ranges=shape_ranges, + fp16_mode=True) # save converted model torch.save(trt_model.state_dict(), save_model_path) @@ -141,12 +139,13 @@ with open(save_engine_path, mode='wb') as f: f.write(trt_model.state_dict()['engine']) ``` -**Note**: -- The input of the engine is the tensor **after preprocess**. -- The output of the engine is `num_dets, bboxes, scores, class_ids`. if you enable the `enable_mask` flag, there will be another output `mask`. -- The bboxes output of the engine did not divided by `scale factor`. +> [!NOTE] +> +> The input of the engine is the tensor **after preprocess**. +> The output of the engine is `num_dets, bboxes, scores, class_ids`. if you enable the `enable_mask` flag, there will be another output `mask`. +> The bboxes output of the engine did not divided by `scale_factor`. -how to use the converted model +how to perform inference with the converted model. ```python from mmdet.apis import inference_detector @@ -157,14 +156,6 @@ trt_detector = create_wrap_detector(trt_model, cfg_path, device_id) # result share same format as mmdetection result = inference_detector(trt_detector, image_path) - -# visualize -trt_detector.show_result( - image_path, - result, - score_thr=score_thr, - win_name='mmdet2trt', - show=True) ``` Try demo in `demo/inference.py`, or `demo/cpp` if you want to do inference with c++ api. @@ -178,6 +169,11 @@ Read [how-does-it-work](https://github.com/NVIDIA-AI-IOT/torch2trt#how-does-it-w ## Support Model/Module +> [!NOTE] +> +> Some models have only been tested on MMDet<3.0. If you found any failed model, +> Please report in the issue. + - [x] Faster R-CNN - [x] Cascade R-CNN - [x] Double-Head R-CNN @@ -218,19 +214,15 @@ Read [how-does-it-work](https://github.com/NVIDIA-AI-IOT/torch2trt#how-does-it-w Tested on: -- torch=1.8.1 -- tensorrt=8.0.1.6 -- mmdetection=2.18.0 -- cuda=11.1 - -If you find any error, please report it in the issue. +- torch=2.2.0 +- tensorrt=8.6.1 +- mmdetection=3.3.0 +- cuda=11.7 ## FAQ read [this page](./docs/FAQ.md) if you meet any problem. -## Contact - -This repo is maintained by [@grimoire](https://github.com/grimoire) +## License -And send your resume to my e-mail if you want to join @OpenMMLab. Please read the JD for detail: [link](https://mp.weixin.qq.com/s/CzrOqITFZX-T_Kcor0hs2g) +This project is released under the [Apache 2.0 license](LICENSE). diff --git a/demo/cpp/README.md b/demo/cpp/README.md index 30527f4..5769747 100644 --- a/demo/cpp/README.md +++ b/demo/cpp/README.md @@ -3,9 +3,10 @@ Best to be used within built docker container (use provided in the project Dockerfile to build the image). ## Requirements + The sample needs additional installation of opencv: -``` +```bash apt-get install -y libopencv-dev ``` @@ -13,7 +14,7 @@ apt-get install -y libopencv-dev Within -``` +```bash mkdir build & cd build cmake -Damirstan_plugin_root= .. make -j4 @@ -21,7 +22,7 @@ make -j4 ## Run the sample -``` +```bash build/trt_sample ``` @@ -31,7 +32,7 @@ The sample is implemented for the TensorRT model converted from mmdetection DCNv To obtain converted model and serialized built TensorRT engine run following command within provided docker container (~/space folder): -``` +```bash mmdet2trt --save-engine=true \ --min-scale 1 3 320 320 \ --opt-scale 1 3 544 960 \ diff --git a/demo/inference.py b/demo/inference.py index 88b0bd9..4669cf6 100644 --- a/demo/inference.py +++ b/demo/inference.py @@ -1,10 +1,12 @@ from argparse import ArgumentParser import torch -from mmdet.apis import inference_detector - from mmdet2trt import mmdet2trt from mmdet2trt.apis import create_wrap_detector +from mmdet.apis import inference_detector +from mmdet.registry import VISUALIZERS + +import mmcv def main(): @@ -33,12 +35,19 @@ def main(): result = inference_detector(trt_detector, image_path) - trt_detector.show_result( - image_path, - result, - score_thr=args.score_thr, - win_name='mmdet2trt_demo', - show=True) + # visualize + visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer') + visualizer = VISUALIZERS.build(visualizer_cfg) + visualizer.dataset_meta = trt_detector.dataset_meta + + image = mmcv.imread(image_path) + visualizer.add_datasample( + 'result', + mmcv.imconvert(image, 'bgr', 'rgb'), + data_sample=result, + draw_gt=False, + show=True, + pred_score_thr=args.score_thr) if __name__ == '__main__': diff --git a/docker/Dockerfile b/docker/Dockerfile index 1aefc2e..017b951 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,8 +1,8 @@ -FROM nvcr.io/nvidia/tensorrt:20.12-py3 +FROM nvcr.io/nvidia/tensorrt:24.01-py3 -ARG CUDA=11.1 -ARG TORCH_VERSION=1.8.0 -ARG TORCHVISION_VERSION=0.9.0 +ARG CUDA=12.1 +ARG TORCH_VERSION=2.2.0 +ARG TORCHVISION_VERSION=0.17.0 ENV DEBIAN_FRONTEND=noninteractive @@ -22,16 +22,8 @@ RUN pip3 install torch==${TORCH_VERSION}+cu${CUDA//./} torchvision==${TORCHVISIO ### install mmcv -RUN pip3 install pytest-runner -RUN pip3 install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu${CUDA//./}/torch${TORCH_VERSION}/index.html - -### git mmdetection -RUN git clone --depth=1 https://github.com/open-mmlab/mmdetection.git /root/space/mmdetection - -### install mmdetection -RUN cd /root/space/mmdetection &&\ - pip3 install -r requirements.txt &&\ - python3 setup.py develop +RUN pip3 install openmim &&\ + mim install mmdet==3.3.0 ### git amirstan plugin RUN git clone --depth=1 https://github.com/grimoire/amirstan_plugin.git /root/space/amirstan_plugin &&\ diff --git a/docs/FAQ.md b/docs/FAQ.md index 3b662d6..5d3bc95 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -6,9 +6,9 @@ - [Model Inference](#model-inference) - [**Q: Inference take a long time on a single image.**](#q-inference-take-a-long-time-on-a-single-image) - [**Q: Memory leak when inference.**](#q-memory-leak-when-inference) - - [**Q: error: parameter check failed at: engine.cpp::setBindingDimensions::1046, condition: profileMinDims.d[i] <= dimensions.d[i]**](#q-error-parameter-check-failed-at-enginecppsetbindingdimensions1046-condition-profilemindimsdi--dimensionsdi) + - [**Q: error: parameter check failed at: engine.cpp::setBindingDimensions::1046, condition: profileMinDims.d\[i\] \<= dimensions.d\[i\]**](#q-error-parameter-check-failed-at-enginecppsetbindingdimensions1046-condition-profilemindimsdi--dimensionsdi) - [**Q: FP16 model is slower than FP32 model**](#q-fp16-model-is-slower-than-fp32-model) - - [**Q: error: [TensorRT] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS_STATUS_SUCCESS**](#q-error-tensorrt-internal-error-assertion-failed-cublasstatus--cublas_status_success) + - [**Q: error: \[TensorRT\] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS\_STATUS\_SUCCESS**](#q-error-tensorrt-internal-error-assertion-failed-cublasstatus--cublas_status_success) This page provides some frequently asked questions and their solutions. @@ -33,13 +33,13 @@ This is a bug of on old version TensorRT, read [this](https://forums.developer.n The input tensor shape is out of the range. Please enlarge the `opt_shape_param` when converting the model. ```python - opt_shape_param=[ - [ - [1,3,224,224], # min tensor shape - [1,3,800,1312], # shape used to do int8 calib - [1,3,1344,1344], # max tensor shape - ] - ] +shape_ranges=dict( + x=dict( + min=[1,3,320,320], + opt=[1,3,800,1344], + max=[1,3,1344,1344], + ) +) ``` ### **Q: FP16 model is slower than FP32 model** diff --git a/docs/getting_started.md b/docs/getting_started.md index 9e44078..b053a09 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -13,26 +13,26 @@ This page provides details about mmdet2trt. ## dynamic shape/batched input -`opt_shape_param` is used to set the min/optimize/max shape of the input tensor. For each dimension in it, min<=optimize<=max. For example: +`shape_ranges` is used to set the min/optimize/max shape of the input tensor. For each dimension in it, min<=optimize<=max. For example: ```python -opt_shape_param=[ - [ - [1,3,320,320], # min shape - [2,3,800,1312], # opt shape - [4,3,1344,1344], # max shape - ] -] +shape_ranges=dict( + x=dict( + min=[1,3,320,320], + opt=[1,3,800,1344], + max=[1,3,1344,1344], + ) +) trt_model = mmdet2trt( ..., - opt_shape_param=opt_shape_param, # set the opt shape + shape_ranges=shape_ranges, # set the opt shape ...) ``` This config will give you input tensor size between (320, 320) to (1344, 1344), max batch_size=4 -**Warning:** - -Dynamic input shape and batch support might need more memory. Use fixed shape to avoid unnecessary memory usage(min=optimize=max). +> [!WARNING] +> +> Dynamic input shape and batch support might need more memory. Use fixed shape to avoid unnecessary memory usage(min=optimize=max). ## fp16 support @@ -46,35 +46,33 @@ trt_model = mmdet2trt( ..., ## int8 support -**int8 mode** needs more configs. - - set `input8_mode=True`. -- provide calibrate dataset, the `__getitem__()` method of dataset should return a list of tensor with shape (C,H,W), the shape **must** be the same as `opt_shape_param[0][1][1:]` (optimize shape). The tensor should do the same preprocess as the model. There is a default dataset, you can also set your custom one. +- provide calibrate dataset, the `__getitem__()` method of dataset should return a list of tensor with shape (C,H,W), the shape **must** be the same as `shape_range['x']['opt'][1:]` (optimize shape). The tensor should do the same preprocess as the model. There is a default dataset, you can also set your custom one. - set the calibrate algorithm, support `entropy` and `minmax`. ```python from mmdet2trt import mmdet2trt, Int8CalibDataset -cfg_path="..." # mmdetection config path -model_path="..." # mmdetection checkpoint path -image_path_list = [...] # lists of image pathes -opt_shape_param=[ - [ - [...], - [...], - [...], - ] -] -calib_dataset = Int8CalibDataset(image_path_list, cfg_path, opt_shape_param) +cfg_path="..." # MMDetection config path +model_path="..." # MMDetection checkpoint path +image_path_list = [...] # lists of image paths +shape_ranges=dict( + x=dict( + min=[...], + opt=[...], + max=[...], + ) +) +calib_dataset = Int8CalibDataset(image_path_list, cfg_path, shape_ranges) trt_model = mmdet2trt(cfg_path, model_path, - opt_shape_param=opt_shape_param, + shape_ranges=shape_ranges, int8_mode=True, int8_calib_dataset=calib_dataset, int8_calib_alg="entropy") ``` -**Warning:** - -Not all models support int8 mode. +> [!WARNING] +> +> Not all models support int8 mode. ## max workspace size @@ -92,9 +90,9 @@ with open(engine_path, mode='wb') as f: Link the `${AMIRSTAN_PLUGIN_DIR}/build/lib/libamirstan_plugin.so` in your project (or you can load it in runtime). Compile and load the engine. -**Warning:** - -might need to invoke `initLibAmirstanInferPlugins()` in [amirInferPlugin.h](https://github.com/grimoire/amirstan_plugin/blob/master/include/plugin/amirInferPlugin.h) to load the plugins. +> [!WARNING] +> +> might need to invoke `initLibAmirstanInferPlugins()` in [amirInferPlugin.h](https://github.com/grimoire/amirstan_plugin/blob/master/include/plugin/amirInferPlugin.h) to load the plugins. The engine only contains inference forward. Preprocess(resize, normalize) and postprocess (divide scale factor) should be done in your project. @@ -162,4 +160,6 @@ set flag `enable_mask` to True trt_model = mmdet2trt(... , enable_mask = True) ``` -**Note**: the mask output is of shape `[batch_size, num_boxes, 28, 28]`, the post-process of masks have not been included in the model. Please implement it by yourself if you want to integrate the converted engine into your own project. +> [!NOTE] +> +> The mask output is of shape `[batch_size, num_boxes, 28, 28]`, the post-process of masks have not been included in the model. Please implement it by yourself if you want to integrate the converted engine into your own project. diff --git a/mmdet2trt/__init__.py b/mmdet2trt/__init__.py index 1f89d75..8f0c6c5 100644 --- a/mmdet2trt/__init__.py +++ b/mmdet2trt/__init__.py @@ -1,4 +1,4 @@ -from .converters import * # noqa: F401,F403 +from . import converters # noqa: F401,F403 from .mmdet2trt import Int8CalibDataset, mask_processor2trt, mmdet2trt __all__ = ['Int8CalibDataset', 'mask_processor2trt', 'mmdet2trt'] diff --git a/mmdet2trt/__main__.py b/mmdet2trt/__main__.py new file mode 100644 index 0000000..15720f0 --- /dev/null +++ b/mmdet2trt/__main__.py @@ -0,0 +1,139 @@ +import logging +from argparse import ArgumentParser +from pathlib import Path + +import torch + +from .mmdet2trt import mmdet2trt + +logger = logging.getLogger('mmdet2trt') + + +def _get_default_path(config_path): + config_path = Path(config_path) + return config_path.with_suffix('.pth').name + + +def _parse_args(): + parser = ArgumentParser() + parser.add_argument('config', help='Path to a mmdet Config file') + parser.add_argument('checkpoint', help='Path to a mmdet Checkpoint file') + parser.add_argument( + '--output', + default=None, + help='Path where tensorrt model will be saved') + parser.add_argument( + '--fp16', action='store_true', help='Enable fp16 inference') + parser.add_argument( + '--enable-mask', action='store_true', help='Enable mask output') + parser.add_argument( + '--save-engine', + action='store_true', + help='Enable saving TensorRT engine. ' + '(will be saved at Path(output).with_suffix(\'.engine\')).', + ) + parser.add_argument( + '--device', + type=str, + default='cuda:0', + help='Device used for conversion.') + parser.add_argument( + '--max-workspace-gb', + type=float, + default=None, + help='The maximum `device` (GPU) temporary memory in GB (gigabytes)' + ' which TensorRT can use at execution time.', + ) + parser.add_argument( + '--min-scale', + type=int, + nargs=4, + default=None, + help='Minimum input scale in ' + '[batch_size, channels, height, width] order.' + ' Only used if all min-scale, opt-scale and max-scale are set.', + ) + parser.add_argument( + '--opt-scale', + type=int, + nargs=4, + default=None, + help='Optimal input scale in ' + '[batch_size, channels, height, width] order.' + ' Only used if all min-scale, opt-scale and max-scale are set.', + ) + parser.add_argument( + '--max-scale', + type=int, + nargs=4, + default=None, + help='Maximum input scale in ' + '[batch_size, channels, height, width] order.' + ' Only used if all min-scale, opt-scale and max-scale are set.', + ) + parser.add_argument( + '--log-level', + default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Python logging level.', + ) + parser.add_argument( + '--trt-log-level', + default='INFO', + choices=['VERBOSE', 'INFO', 'WARNING', 'ERROR'], + help='TensorRT logging level.', + ) + args = parser.parse_args() + return args + + +def _save_model(trt_model, output_path, save_engine): + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + logger.info('Saving TRT model to: {}'.format(output_path)) + torch.save(trt_model.state_dict(), output_path) + + if save_engine: + logger.info('Saving TRT model engine to: {}'.format( + output_path.with_suffix('.engine'))) + with open(output_path.with_suffix('.engine'), 'wb') as f: + f.write(trt_model.state_dict()['engine']) + + +def main(): + args = _parse_args() + + logger.setLevel(getattr(logging, args.log_level)) + + if all( + getattr(args, x) is not None + for x in ['min_scale', 'opt_scale', 'max_scale']): + shape_range = dict( + x=dict(min=args.min_scale, opt=args.opt_scale, max=args.max_scale)) + else: + shape_range = None + + work_space_size = None + if args.max_workspace_gb is not None: + work_space_size = int(args.max_workspace_gb * 1e9) + + trt_model = mmdet2trt( + args.config, + args.checkpoint, + device=args.device, + fp16_mode=args.fp16, + max_workspace_size=work_space_size, + shape_ranges=shape_range, + trt_log_level=args.trt_log_level, + enable_mask=args.enable_mask) + + output_path = args.output + if output_path is None: + output_path = _get_default_path(args.config) + + _save_model(trt_model, output_path, args.save_engine) + + +if __name__ == '__main__': + main() diff --git a/mmdet2trt/apis/inference.py b/mmdet2trt/apis/inference.py index 93e3434..50f07bf 100644 --- a/mmdet2trt/apis/inference.py +++ b/mmdet2trt/apis/inference.py @@ -1,15 +1,21 @@ import logging +from typing import List, Tuple, Union -import mmcv +import mmengine import numpy as np import torch from addict import Addict -from mmdet.core import bbox2result -from mmdet.datasets.pipelines import Compose from mmdet.models import BaseDetector from mmdet.models.roi_heads.mask_heads import FCNMaskHead +from mmdet.structures import DetDataSample, OptSampleList, SampleList +from mmengine.registry import DATASETS, MODELS, init_default_scope +from mmengine.structures import InstanceData +from torch import Tensor from torch2trt_dynamic import TRTModule +import mmcv +from mmcv.transforms import Compose + logger = logging.getLogger('mmdet2trt') @@ -22,7 +28,7 @@ def init_trt_model(trt_model_path): def inference_trt_model(model, img, cfg, device): if isinstance(cfg, str): - cfg = mmcv.Config.fromfile(cfg) + cfg = mmengine.Config.fromfile(cfg) device = torch.device(device) @@ -51,7 +57,7 @@ def inference_trt_model(model, img, cfg, device): scale_factor = torch.tensor( scale_factor, dtype=torch.float32, device=device) - with torch.no_grad(): + with torch.inference_mode(): result = model(tensor) result = list(result) result[1] = result[1] / scale_factor @@ -59,43 +65,20 @@ def inference_trt_model(model, img, cfg, device): return result -def get_classes_from_config(model_cfg): - model_cfg_str = model_cfg - if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) +def get_dataset_meta(model_cfg): + init_default_scope(model_cfg.get('default_scope', 'mmdet')) + dataset = model_cfg.val_dataloader.dataset + dataset.lazy_init = True + dataset = DATASETS.build(model_cfg.val_dataloader.dataset) + return dataset.metainfo - from mmdet.datasets import DATASETS, build_dataset +def get_classes_from_config(model_cfg): try: - dataset = build_dataset(model_cfg) - return dataset.CLASSES - except Exception: - logger.warning( - 'Can not load dataset from config. Use default CLASSES instead.') - - module_dict = DATASETS.module_dict - data_cfg = model_cfg.data - - def get_module_from_train_val(train_val_cfg): - while train_val_cfg.type == 'RepeatDataset' or \ - train_val_cfg.type == 'MultiImageMixDataset': - train_val_cfg = train_val_cfg.dataset - return module_dict[train_val_cfg.type] - - data_cfg_type_list = ['train', 'val', 'test'] - - MODULE = None - for data_cfg_type in data_cfg_type_list: - if data_cfg_type in data_cfg: - tmp_data_cfg = data_cfg.get(data_cfg_type) - MODULE = get_module_from_train_val(tmp_data_cfg) - if 'classes' in tmp_data_cfg: - return MODULE.get_classes(tmp_data_cfg.classes) - break - - assert MODULE is not None, f'No dataset config found in: {model_cfg_str}' - - return MODULE.CLASSES + return get_dataset_meta(model_cfg)['classes'] + except Exception as e: + logger.warning('Load class names from dataset failed. with error:') + raise e class TRTDetector(BaseDetector): @@ -103,11 +86,11 @@ class TRTDetector(BaseDetector): def __init__(self, trt_module, model_cfg, device_id=0): super().__init__() - - self._dummy_param = torch.nn.Parameter(torch.tensor(0.0)) if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) - self.CLASSES = get_classes_from_config(model_cfg) + model_cfg = mmengine.Config.fromfile(model_cfg) + init_default_scope(model_cfg.get('default_scope', 'mmdet')) + # self.CLASSES = get_classes_from_config(model_cfg) + self.dataset_meta = get_dataset_meta(model_cfg) self.cfg = model_cfg self.device_id = device_id @@ -117,93 +100,96 @@ def __init__(self, trt_module, model_cfg, device_id=0): model.load_state_dict(torch.load(trt_module)) self.model = model - def simple_test(self, img, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + self.data_preprocessor = self._build_data_preprocessor(model_cfg) - def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + def _build_data_preprocessor(self, cfg): + return MODELS.build(cfg.model.data_preprocessor) def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') - def forward_train(self, imgs, img_metas, **kwargs): + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: raise NotImplementedError('This method is not implemented.') - def val_step(self, data, optimizer): + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: raise NotImplementedError('This method is not implemented.') - def train_step(self, data, optimizer): - raise NotImplementedError('This method is not implemented.') + @torch.inference_mode() + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: - def aforward_test(self, *, img, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + def __rescale_bboxes(bboxes, scale_factor): + scale_factor = bboxes.new_tensor(scale_factor)[None, :] + if scale_factor.size(-1) == 2: + scale_factor = scale_factor.repeat(1, 2) + return bboxes / scale_factor - def async_simple_test(self, img, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') - - def forward(self, img, img_metas, *args, **kwargs): - outputs = self.forward_test(img, img_metas, *args, **kwargs) + outputs = self.forward_test(batch_inputs) batch_num_dets, batch_boxes, batch_scores, batch_labels = outputs[:4] - batch_dets = torch.cat( - [batch_boxes, batch_scores.unsqueeze(-1)], dim=-1) + batch_num_dets = batch_num_dets.cpu() batch_masks = None if len(outputs) < 5 else outputs[4] - batch_size = img[0].shape[0] - img_metas = img_metas[0] + batch_size = batch_inputs.shape[0] + results = [] - rescale = kwargs.get('rescale', True) for i in range(batch_size): + in_data_sample = batch_data_samples[i] + metainfo = in_data_sample.metainfo + num_dets = batch_num_dets[i] - dets, labels = batch_dets[i][:num_dets], batch_labels[i][:num_dets] - old_dets = dets.clone() - labels = labels.int() + bboxes = batch_boxes[i, :num_dets] + labels = batch_labels[i, :num_dets].int() + scores = batch_scores[i, :num_dets] + + old_bboxes = bboxes + if rescale: - scale_factor = img_metas[i]['scale_factor'] - - if isinstance(scale_factor, (list, tuple, np.ndarray)): - assert len(scale_factor) == 4 - scale_factor = dets.new_tensor(scale_factor)[ - None, :] # [1,4] - dets[:, :4] /= scale_factor - - if 'border' in img_metas[i]: - # offset pixel of the top-left corners between original image - # and padded/enlarged image, 'border' is used when exporting - # CornerNet and CentripetalNet to onnx - x_off = img_metas[i]['border'][2] - y_off = img_metas[i]['border'][0] - dets[:, [0, 2]] -= x_off - dets[:, [1, 3]] -= y_off - dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype) - - dets_results = bbox2result(dets, labels, len(self.CLASSES)) + assert 'scale_factor' in metainfo + bboxes = __rescale_bboxes(bboxes, metainfo['scale_factor']) + + pred_instances = InstanceData(metainfo=metainfo) + pred_instances.scores = scores + pred_instances.bboxes = bboxes + pred_instances.labels = labels + + # if 'border' in img_metas[i]: + # # offset pixel of the top-left corners between original image + # # and padded/enlarged image, 'border' is used when exporting + # # CornerNet and CentripetalNet to onnx + # x_off = img_metas[i]['border'][2] + # y_off = img_metas[i]['border'][0] + # dets[:, [0, 2]] -= x_off + # dets[:, [1, 3]] -= y_off + # dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype) if batch_masks is not None: - masks = batch_masks[i][:num_dets].unsqueeze(1) - masks = masks.detach().cpu().numpy() - num_classes = len(self.CLASSES) + masks = batch_masks[i, :num_dets].unsqueeze(1) class_agnostic = True - segms_results = [[] for _ in range(num_classes)] - if num_dets>0: - for i in range(batch_size): - segms_results = FCNMaskHead.get_seg_masks( - Addict( - num_classes=num_classes, - class_agnostic=class_agnostic), - masks, - old_dets, - labels, - rcnn_test_cfg=Addict(mask_thr_binary=0.5), - ori_shape=img_metas[i]['ori_shape'], - scale_factor=scale_factor, - rescale=rescale) - results.append((dets_results, segms_results)) - else: - results.append(dets_results) + if num_dets > 0: + masks = FCNMaskHead._predict_by_feat_single( + Addict(class_agnostic=class_agnostic), + masks, + old_bboxes, + labels, + img_meta=metainfo, + rcnn_test_cfg=self.cfg.model.test_cfg.rcnn, + rescale=rescale, + activate_map=True) + pred_instances.masks = masks + + out_data_sample = DetDataSample( + metainfo=metainfo, pred_instances=pred_instances) + results.append(out_data_sample) return results - def forward_test(self, imgs, *args, **kwargs): - input_data = imgs[0].contiguous() - with torch.cuda.device(self.device_id), torch.no_grad(): + def forward_test(self, imgs): + input_data = imgs.contiguous() + with torch.cuda.device(self.device_id), torch.inference_mode(): outputs = self.model(input_data) return outputs diff --git a/mmdet2trt/converters/SAConv2d.py b/mmdet2trt/converters/SAConv2d.py index 3751d5e..4838553 100644 --- a/mmdet2trt/converters/SAConv2d.py +++ b/mmdet2trt/converters/SAConv2d.py @@ -1,9 +1,10 @@ -import mmcv.cnn -import mmcv.ops import torch import torch.nn.functional as F from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter +import mmcv.cnn +import mmcv.ops + @tensorrt_converter('mmcv.ops.saconv.SAConv2d.forward', is_real=False) def convert_SAConv2d(ctx): diff --git a/mmdet2trt/converters/__init__.py b/mmdet2trt/converters/__init__.py index 3ef4b90..bb61837 100644 --- a/mmdet2trt/converters/__init__.py +++ b/mmdet2trt/converters/__init__.py @@ -1,28 +1,20 @@ -from .anchor_generator import convert_AnchorGeneratorDynamic -from .batched_nms import convert_batchednms -from .bfp_forward import convert_BFP -from .carafe import (convert_carafe_feature_reassemble, - convert_carafe_kernel_normalizer, - convert_carafe_tensor_add) -from .ConvAWS2d import convert_ConvAWS2d -from .ConvWS2d import convert_ConvWS2d -from .DeformConv import convert_DeformConv, convert_ModulatedDeformConv -from .DeformPool import convert_DeformPool -from .delta2bbox_custom import convert_delta2bbox -from .generalized_attention import convert_GeneralizeAttention -from .MaskedConv import convert_MaskedConv -from .mmcv_roi_aligin import convert_mmcv_RoIAlign +import torch # noqa: F401 + +from . import ConvAWS2d # noqa: F401 +from . import ConvWS2d # noqa: F401 +from . import DeformConv # noqa: F401 +from . import DeformPool # noqa: F401 +from . import MaskedConv # noqa: F401 +from . import RoiExtractor # noqa: F401 +from . import SAConv2d # noqa: F401 +from . import anchor_generator # noqa: F401 +from . import batched_nms # noqa: F401 +from . import bfp_forward # noqa: F401 +from . import carafe # noqa: F401 +from . import delta2bbox_custom # noqa: F401 +from . import generalized_attention # noqa: F401 +from . import mmcv_roi_aligin # noqa: F401 +from . import vfnet # noqa: F401 from .mmdet2trtOps import convert_adaptive_max_pool2d_by_input -from .RoiExtractor import convert_roiextractor -from .SAConv2d import convert_SAConv2d -from .vfnet import convert_vfnet_star_dcn_offset -__all__ = [ - 'convert_AnchorGeneratorDynamic', 'convert_batchednms', 'convert_BFP', - 'convert_carafe_feature_reassemble', 'convert_carafe_kernel_normalizer', - 'convert_carafe_tensor_add', 'convert_ConvAWS2d', 'convert_ConvWS2d', - 'convert_DeformConv', 'convert_ModulatedDeformConv', 'convert_DeformPool', - 'convert_delta2bbox', 'convert_GeneralizeAttention', 'convert_MaskedConv', - 'convert_mmcv_RoIAlign', 'convert_adaptive_max_pool2d_by_input', - 'convert_roiextractor', 'convert_SAConv2d', 'convert_vfnet_star_dcn_offset' -] +__all__ = ['convert_adaptive_max_pool2d_by_input'] diff --git a/mmdet2trt/converters/anchor_generator.py b/mmdet2trt/converters/anchor_generator.py index 9fdede7..4ddca88 100644 --- a/mmdet2trt/converters/anchor_generator.py +++ b/mmdet2trt/converters/anchor_generator.py @@ -25,26 +25,6 @@ def convert_AnchorGeneratorDynamic(ctx): 'ag_' + str(id(module)), stride=stride) else: print('no base_anchors in {}'.format(ag.generator)) - # scales = ag.scales.detach().cpu().numpy().astype(np.float32) - # ratios = ag.ratios.detach().cpu().numpy().astype(np.float32) - # scale_major = ag.scale_major - # ctr = ag.ctr - # if ctr is None: - # # center_x = -1 - # # center_y = -1 - # center_x = 0 - # center_y = 0 - # else: - # center_x, center_y = ag.ctr - - # plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)), - # base_size=base_size, - # stride=stride, - # scales=scales, - # ratios=ratios, - # scale_major=scale_major, - # center_x=center_x, - # center_y=center_y) custom_layer = ctx.network.add_plugin_v2( inputs=[input_trt, base_anchors_trt], plugin=plugin) diff --git a/mmdet2trt/converters/bfp_forward.py b/mmdet2trt/converters/bfp_forward.py index 46bba72..c7b8c4f 100644 --- a/mmdet2trt/converters/bfp_forward.py +++ b/mmdet2trt/converters/bfp_forward.py @@ -1,8 +1,7 @@ +import mmdet2trt.ops as mmdet2trt_ops import torch.nn.functional as F from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter -import mmdet2trt.ops as mmdet2trt_ops - @tensorrt_converter('mmdet.models.necks.BFP.forward', is_real=False) def convert_BFP(ctx): diff --git a/mmdet2trt/converters/delta2bbox_custom.py b/mmdet2trt/converters/delta2bbox_custom.py index ee9b0c0..4f14c0d 100644 --- a/mmdet2trt/converters/delta2bbox_custom.py +++ b/mmdet2trt/converters/delta2bbox_custom.py @@ -5,8 +5,8 @@ from .plugins import create_delta2bbox_custom_plugin -@tensorrt_converter( - 'mmdet2trt.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox_custom_func') +@tensorrt_converter('mmdet2trt.models.task_modules.coders' + '.delta_xywh_bbox_coder.delta2bbox_custom_func') def convert_delta2bbox(ctx): cls_scores = get_arg(ctx, 'cls_scores', pos=0, default=None) bbox_preds = get_arg(ctx, 'bbox_preds', pos=1, default=None) diff --git a/mmdet2trt/converters/generalized_attention.py b/mmdet2trt/converters/generalized_attention.py index d06f8fb..c3154d2 100644 --- a/mmdet2trt/converters/generalized_attention.py +++ b/mmdet2trt/converters/generalized_attention.py @@ -1,12 +1,11 @@ import math +import mmdet2trt import numpy as np import torch import torch.nn.functional as F from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter -import mmdet2trt - def get_position_embedding(self, x_q, diff --git a/mmdet2trt/core/__init__.py b/mmdet2trt/core/__init__.py index 987c9f2..0503d7c 100644 --- a/mmdet2trt/core/__init__.py +++ b/mmdet2trt/core/__init__.py @@ -1,3 +1 @@ -from .anchor import * # noqa: F401,F403 -from .bbox import * # noqa: F401,F403 from .post_processing import * # noqa: F401,F403 diff --git a/mmdet2trt/core/bbox/__init__.py b/mmdet2trt/core/bbox/__init__.py deleted file mode 100644 index c3ec088..0000000 --- a/mmdet2trt/core/bbox/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .coder import * # noqa: F401,F403 -from .iou_calculators import * # noqa: F401,F403 -from .transforms import batched_bbox_cxcywh_to_xyxy, batched_distance2bbox - -__all__ = ['batched_bbox_cxcywh_to_xyxy', 'batched_distance2bbox'] diff --git a/mmdet2trt/core/bbox/iou_calculators/__init__.py b/mmdet2trt/core/bbox/iou_calculators/__init__.py deleted file mode 100644 index 44cda08..0000000 --- a/mmdet2trt/core/bbox/iou_calculators/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .iou2d_calculator import bbox_overlaps_batched - -__all__ = ['bbox_overlaps_batched'] diff --git a/mmdet2trt/mmdet2trt.py b/mmdet2trt/mmdet2trt.py index 68387fd..dd806f9 100644 --- a/mmdet2trt/mmdet2trt.py +++ b/mmdet2trt/mmdet2trt.py @@ -1,19 +1,14 @@ -import argparse import logging import time -from argparse import ArgumentParser -from pathlib import Path +from typing import Any, Dict -import mmcv +import mmengine import tensorrt as trt import torch -from mmdet.apis import init_detector -from mmdet.apis.inference import LoadImage -from mmdet.datasets.pipelines import Compose -from torch2trt_dynamic import torch2trt_dynamic - -from mmdet2trt.models.builder import build_wraper +from mmdet2trt.models.builder import build_wrapper from mmdet2trt.models.detectors import TwoStageDetectorWraper +from mmdet.apis import init_detector +from torch2trt_dynamic import BuildEngineConfig, module2trt logger = logging.getLogger('mmdet2trt') @@ -24,23 +19,26 @@ class Int8CalibDataset(): feed to int8_calib_dataset """ - def __init__(self, image_paths, config, opt_shape_param): + def __init__(self, image_paths, config, shape_ranges): r""" datas used to calibrate int8 model feed to int8_calib_dataset Args: image_paths (list[str]): image paths to calib config (str|dict): config of mmdetection model - opt_shape_param: same as mmdet2trt + shape_ranges: same as mmdet2trt """ + from mmcv.transforms import Compose + from mmengine.registry import init_default_scope if isinstance(config, str): - config = mmcv.Config.fromfile(config) + config = mmengine.Config.fromfile(config) + init_default_scope(config.get('default_scope', 'mmdet')) self.cfg = config self.image_paths = image_paths - self.opt_shape = opt_shape_param[0][1] + self.opt_shape = shape_ranges['x']['opt'] - test_pipeline = [LoadImage()] + config.data.test.pipeline[1:] + test_pipeline = config.val_dataloader.dataset.pipeline self.test_pipeline = Compose(test_pipeline) def __len__(self): @@ -49,51 +47,94 @@ def __len__(self): def __getitem__(self, index): image_path = self.image_paths[index] - data = dict(img=image_path) + data = dict(img=image_path, img_path=image_path) data = self.test_pipeline(data) - tensor = data['img'][0].unsqueeze(0) + tensor = data['inputs'].unsqueeze(0) tensor = torch.nn.functional.interpolate( tensor, self.opt_shape[-2:]).squeeze(0) - return [tensor] + return dict(x=tensor.cuda()) + + +def _get_shape_ranges(config): + img_scale = config.test_pipeline[1]['scale'] + min_scale = min(img_scale) + max_scale = max(img_scale) + 32 + opt_shape_param = dict( + x=dict( + min=[1, 3, min_scale, min_scale], + opt=[1, 3, img_scale[1], img_scale[0]], + max=[1, 3, max_scale, max_scale], + )) + return opt_shape_param + + +def _make_dummy_input(shape_ranges, device): + dummy_shape = shape_ranges['x']['opt'] + dummy_input = torch.rand(dummy_shape).to(device) + dummy_input = (dummy_input - 0.45) / 0.27 + dummy_input = dummy_input.contiguous() + return dummy_input + + +def _get_trt_calib_algorithm(int8_calib_alg): + int8_calib_algorithm = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 + if int8_calib_alg == 'minmax': + int8_calib_algorithm = trt.CalibrationAlgoType.MINMAX_CALIBRATION + elif int8_calib_alg == 'entropy': + int8_calib_algorithm = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 + else: + raise ValueError('int8_calib_alg should be "minmax" or "entropy"') + return int8_calib_algorithm def mmdet2trt(config, - checkpoint, - device='cuda:0', - fp16_mode=False, - int8_mode=False, - int8_calib_dataset=None, - int8_calib_alg='entropy', - max_workspace_size=0.5e9, - opt_shape_param=None, - trt_log_level='INFO', - return_wrap_model=False, - output_names=['num_detections', 'boxes', 'scores', 'classes'], + checkpoint: str, + device: str = 'cuda:0', + fp16_mode: bool = False, + int8_mode: bool = False, + int8_calib_dataset: Any = None, + int8_calib_alg: str = 'entropy', + max_workspace_size: int = None, + shape_ranges: Dict[str, Dict] = None, + trt_log_level: str = 'INFO', + return_wrap_model: bool = False, enable_mask=False): r""" - create tensorrt model from mmdetection. + create TensorRT model from MMDetection. Args: - config (str): config file path of mmdetection model - checkpoint (str): checkpoint file path of mmdetection model + config (str): config file path of MMDetection model + checkpoint (str): checkpoint file path of MMDetection model device (str): convert gpu device fp16_mode (bool): create fp16 mode engine. int8_mode (bool): create int8 mode engine. int8_calib_dataset (object): dataset object used to do data calibrate int8_calib_alg (str): how to calibrate int8, ["minmax", "entropy"] - max_workspace_size (int): tensorrt workspace size. + max_workspace_size (int): TensorRT workspace size. some tactic might need large workspace. - opt_shape_param (list[list[list[int]]]): the min/optimize/max shape of + shape_ranges (Dict[str, Dict]): the min/optimize/max shape of input tensor - trt_log_level (str): tensorrt log level, + trt_log_level (str): TensorRT log level, options: ["VERBOSE", "INFO", "WARNING", "ERROR"] return_wrap_model (bool): return pytorch wrap model, used for debug - output_names (str): the output names of tensorrt engine - enable_mask (bool): weither output the instance segmentation result - (w/o postprocess) + enable_mask (bool): if output the instance segmentation result + (w/o post-process) """ + def _make_build_engine_config(shape_ranges, max_workspace_size, + int8_calib_dataset): + int8_calib_algorithm = _get_trt_calib_algorithm(int8_calib_alg) + build_engine_config = BuildEngineConfig( + shape_ranges=shape_ranges, + pool_size=max_workspace_size, + fp16=fp16_mode, + int8=int8_mode, + int8_calib_dataset=int8_calib_dataset, + int8_calib_algorithm=int8_calib_algorithm, + int8_batch_size=1) + return build_engine_config + device = torch.device(device) logger.info('Loading model from config: {}'.format(config)) @@ -101,60 +142,37 @@ def mmdet2trt(config, cfg = torch_model.cfg logger.info('Wrapping model') - if enable_mask and output_names is not None and len(output_names) != 5: - logger.warning('mask mode require len(output_names)==5 ' + - 'but get output_names=' + str(output_names)) - output_names = None wrap_config = {'enable_mask': enable_mask} - wrap_model = build_wraper( + wrapped_model = build_wrapper( torch_model, TwoStageDetectorWraper, wrap_config=wrap_config) - if opt_shape_param is None: - img_scale = cfg.test_pipeline[1]['img_scale'] - min_scale = min(img_scale) - max_scale = max(img_scale) + 32 - opt_shape_param = [[ - [1, 3, min_scale, min_scale], - [1, 3, img_scale[1], img_scale[0]], - [1, 3, max_scale, max_scale], - ]] - - dummy_shape = opt_shape_param[0][1] - dummy_input = torch.rand(dummy_shape).to(device) - dummy_input = (dummy_input - 0.45) / 0.27 - dummy_input = dummy_input.contiguous() + if shape_ranges is None: + shape_ranges = _get_shape_ranges(cfg) + + dummy_input = _make_dummy_input(shape_ranges, device) - logger.info('Model warmup') - with torch.no_grad(): - wrap_model(dummy_input) + logger.info('Model warmup.') + with torch.cuda.device(device), torch.inference_mode(): + wrapped_model(dummy_input) logger.info('Converting model') start = time.time() - with torch.cuda.device(device), torch.no_grad(): - int8_calib_algorithm = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 - if int8_calib_alg == 'minmax': - int8_calib_algorithm = trt.CalibrationAlgoType.MINMAX_CALIBRATION - elif int8_calib_alg == 'entropy': - int8_calib_algorithm = \ - trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 - trt_model = torch2trt_dynamic( - wrap_model, [dummy_input], - log_level=getattr(trt.Logger, trt_log_level), - fp16_mode=fp16_mode, - opt_shape_param=opt_shape_param, - max_workspace_size=int(max_workspace_size), - keep_network=False, - strict_type_constraints=True, - output_names=output_names, - int8_mode=int8_mode, - int8_calib_dataset=int8_calib_dataset, - int8_calib_algorithm=int8_calib_algorithm) + with torch.cuda.device(device), torch.inference_mode(): + trt_log_level = getattr(trt.Logger, trt_log_level) + build_engine_config = _make_build_engine_config( + shape_ranges=shape_ranges, + max_workspace_size=max_workspace_size, + int8_calib_dataset=int8_calib_dataset) + trt_model = module2trt( + wrapped_model, [dummy_input], + config=build_engine_config, + log_level=trt_log_level) duration = time.time() - start logger.info('Conversion took {} s'.format(duration)) if return_wrap_model: - return trt_model, wrap_model + return trt_model, wrapped_model return trt_model @@ -166,16 +184,15 @@ def mask_processor2trt(max_width, mask_size=[28, 28], device='cuda:0', fp16_mode=False, - max_workspace_size=0.5e9, + max_workspace_size=None, trt_log_level='INFO', - return_wrap_model=False, - output_names=None): + return_wrap_model=False): from mmdet2trt.models.roi_heads.mask_heads.fcn_mask_head import \ MaskProcessor logger.info('Wrapping MaskProcessor') - wrap_model = MaskProcessor(max_width=max_width, max_height=max_height) + wrapped_model = MaskProcessor(max_width=max_width, max_height=max_height) batch_size = max_batch_size num_boxes = max_box_per_batch @@ -205,145 +222,20 @@ def mask_processor2trt(max_width, logger.info('Converting MaskProcessor') start = time.time() with torch.cuda.device(device), torch.no_grad(): - trt_model = torch2trt_dynamic( - wrap_model, [dummy_mask, dummy_box], - log_level=getattr(trt.Logger, trt_log_level), - fp16_mode=fp16_mode, - opt_shape_param=opt_shape_param, - max_workspace_size=int(max_workspace_size), - keep_network=False, - strict_type_constraints=True, - output_names=output_names) + trt_log_level = getattr(trt.Logger, trt_log_level) + build_engine_config = BuildEngineConfig( + shape_ranges=opt_shape_param, + pool_size=max_workspace_size, + fp16=fp16_mode) + trt_model = module2trt( + wrapped_model, [dummy_mask, dummy_box], + config=build_engine_config, + log_level=trt_log_level) duration = time.time() - start logger.info('Conversion took {} s'.format(duration)) if return_wrap_model: - return trt_model, wrap_model + return trt_model, wrapped_model return trt_model - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def main(): - parser = ArgumentParser() - parser.add_argument('config', help='Path to a mmdet Config file') - parser.add_argument('checkpoint', help='Path to a mmdet Checkpoint file') - parser.add_argument( - 'output', help='Path where tensorrt model will be saved') - parser.add_argument( - '--fp16', type=str2bool, default=False, help='Enable fp16 inference') - parser.add_argument( - '--enable-mask', - type=str2bool, - default=False, - help='Enable mask output') - parser.add_argument( - '--save-engine', - type=str2bool, - default=False, - help='Enable saving TensorRT engine. ' - '(will be saved at Path(output).with_suffix(\'.engine\')).', - ) - parser.add_argument( - '--device', - type=str, - default='cuda:0', - help='Device used for conversion.') - parser.add_argument( - '--max-workspace-gb', - type=float, - default=0.5, - help='The maximum `device` (GPU) temporary memory in GB (gigabytes)' - ' which TensorRT can use at execution time.', - ) - parser.add_argument( - '--min-scale', - type=int, - nargs=4, - default=None, - help='Minimum input scale in ' - '[batch_size, channels, height, width] order.' - ' Only used if all min-scale, opt-scale and max-scale are set.', - ) - parser.add_argument( - '--opt-scale', - type=int, - nargs=4, - default=None, - help='Optimal input scale in ' - '[batch_size, channels, height, width] order.' - ' Only used if all min-scale, opt-scale and max-scale are set.', - ) - parser.add_argument( - '--max-scale', - type=int, - nargs=4, - default=None, - help='Maximum input scale in ' - '[batch_size, channels, height, width] order.' - ' Only used if all min-scale, opt-scale and max-scale are set.', - ) - parser.add_argument( - '--log-level', - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - help='Python logging level.', - ) - parser.add_argument( - '--trt-log-level', - default='INFO', - choices=['VERBOSE', 'INFO', 'WARNING', 'ERROR'], - help='TensorRT logging level.', - ) - parser.add_argument( - '--output-names', - nargs=4, - type=str, - default=['num_detections', 'boxes', 'scores', 'classes'], - help='Names for the output nodes of the created TRTModule', - ) - args = parser.parse_args() - - logger.setLevel(getattr(logging, args.log_level)) - - if all( - getattr(args, x) is not None - for x in ['min_scale', 'opt_scale', 'max_scale']): - opt_shape_param = [[args.min_scale, args.opt_scale, args.max_scale]] - else: - opt_shape_param = None - - trt_model = mmdet2trt( - args.config, - args.checkpoint, - device=args.device, - fp16_mode=args.fp16, - max_workspace_size=int(args.max_workspace_gb * 1e9), - opt_shape_param=opt_shape_param, - trt_log_level=args.trt_log_level, - output_names=args.output_names, - enable_mask=args.enable_mask) - - logger.info('Saving TRT model to: {}'.format(args.output)) - torch.save(trt_model.state_dict(), args.output) - - if args.save_engine: - logger.info('Saving TRT model engine to: {}'.format( - Path(args.output).with_suffix('.engine'))) - with open(Path(args.output).with_suffix('.engine'), 'wb') as f: - f.write(trt_model.state_dict()['engine']) - - -if __name__ == '__main__': - main() diff --git a/mmdet2trt/models/__init__.py b/mmdet2trt/models/__init__.py index 118c9bf..09ace4d 100644 --- a/mmdet2trt/models/__init__.py +++ b/mmdet2trt/models/__init__.py @@ -1,9 +1,10 @@ -from .backbones import * # noqa: F401,F403 -from .builder import build_wraper, register_wraper -from .dense_heads import * # noqa: F401,F403 -from .detectors import * # noqa: F401,F403 -from .necks import * # noqa: F401,F403 -from .roi_heads import * # noqa: F401,F403 -from .utils import * # noqa: F401,F403 +from . import backbones # noqa: F401,F403 +from . import dense_heads # noqa: F401,F403 +from . import detectors # noqa: F401,F403 +from . import layers # noqa: F401,F403 +from . import necks # noqa: F401,F403 +from . import roi_heads # noqa: F401,F403 +from . import task_modules # noqa: F401,F403 +from .builder import build_wrapper, register_wrapper -__all__ = ['build_wraper', 'register_wraper'] +__all__ = ['build_wrapper', 'register_wrapper'] diff --git a/mmdet2trt/models/backbones/base_backbone.py b/mmdet2trt/models/backbones/base_backbone.py index b4cf61a..1731a4f 100644 --- a/mmdet2trt/models/backbones/base_backbone.py +++ b/mmdet2trt/models/backbones/base_backbone.py @@ -1,17 +1,16 @@ import torch.nn as nn +from mmdet2trt.models.builder import register_wrapper -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.backbones.CSPDarknet') -@register_wraper('mmdet.models.backbones.MobileNetV2') -@register_wraper('mmdet.models.backbones.ResNet') -@register_wraper('mmdet.models.backbones.SSDVGG') -@register_wraper('mmdet.models.backbones.HRNet') -@register_wraper('mmdet.models.backbones.Darknet') -@register_wraper('mmdet.models.backbones.DetectoRS_ResNet') -@register_wraper('mmdet.models.backbones.HourglassNet') -@register_wraper('mmdet.models.backbones.resnext.ResNeXt') +@register_wrapper('mmdet.models.backbones.CSPDarknet') +@register_wrapper('mmdet.models.backbones.MobileNetV2') +@register_wrapper('mmdet.models.backbones.ResNet') +@register_wrapper('mmdet.models.backbones.SSDVGG') +@register_wrapper('mmdet.models.backbones.HRNet') +@register_wrapper('mmdet.models.backbones.Darknet') +@register_wrapper('mmdet.models.backbones.DetectoRS_ResNet') +@register_wrapper('mmdet.models.backbones.HourglassNet') +@register_wrapper('mmdet.models.backbones.resnext.ResNeXt') class BaseBackboneWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/builder.py b/mmdet2trt/models/builder.py index 7258704..ba3fbb3 100644 --- a/mmdet2trt/models/builder.py +++ b/mmdet2trt/models/builder.py @@ -1,15 +1,16 @@ import logging -import mmcv.ops # noqa: F401,F403 import mmdet # noqa: F401,F403 from mmdet import models # noqa: F401,F403 +import mmcv.ops # noqa: F401,F403 + WRAPER_DICT = {} logger = logging.getLogger('mmdet2trt') -def register_wraper(module_name): +def register_wrapper(module_name): try: mmdet_module = eval(module_name) @@ -32,7 +33,7 @@ def register_func(wrap_cls): return register_func -def build_wraper(module, default_wraper=None, **kwargs): +def build_wrapper(module, default_wraper=None, **kwargs): model_type = module.__class__ wrap_model = None diff --git a/mmdet2trt/models/dense_heads/anchor_free_head.py b/mmdet2trt/models/dense_heads/anchor_free_head.py index 23a6c76..7c2447d 100644 --- a/mmdet2trt/models/dense_heads/anchor_free_head.py +++ b/mmdet2trt/models/dense_heads/anchor_free_head.py @@ -1,7 +1,6 @@ -from torch import nn - import mmdet2trt.ops as mm2trt_ops from mmdet2trt.core.post_processing.batched_nms import BatchedNMS +from torch import nn class AnchorFreeHeadWraper(nn.Module): diff --git a/mmdet2trt/models/dense_heads/anchor_head.py b/mmdet2trt/models/dense_heads/anchor_head.py index 2d7af14..1bed260 100644 --- a/mmdet2trt/models/dense_heads/anchor_head.py +++ b/mmdet2trt/models/dense_heads/anchor_head.py @@ -1,24 +1,23 @@ -import torch -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.dense_heads.FSAFHead') -@register_wraper('mmdet.models.RetinaSepBNHead') -@register_wraper('mmdet.models.FreeAnchorRetinaHead') -@register_wraper('mmdet.models.RetinaHead') -@register_wraper('mmdet.models.SSDHead') -@register_wraper('mmdet.models.AnchorHead') +@register_wrapper('mmdet.models.dense_heads.FSAFHead') +@register_wrapper('mmdet.models.RetinaSepBNHead') +@register_wrapper('mmdet.models.FreeAnchorRetinaHead') +@register_wrapper('mmdet.models.RetinaHead') +@register_wrapper('mmdet.models.SSDHead') +@register_wrapper('mmdet.models.AnchorHead') class AnchorHeadWraper(nn.Module): def __init__(self, module): super(AnchorHeadWraper, self).__init__() self.module = module - self.anchor_generator = build_wraper(self.module.anchor_generator) - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.prior_generator = build_wrapper(self.module.prior_generator) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.test_cfg = module.test_cfg self.num_classes = self.module.num_classes @@ -33,7 +32,7 @@ def forward(self, feat, x): cls_scores, bbox_preds = module(feat) - mlvl_anchors = self.anchor_generator( + mlvl_anchors = self.prior_generator( cls_scores, device=cls_scores[0].device) mlvl_scores = [] diff --git a/mmdet2trt/models/dense_heads/atss_head.py b/mmdet2trt/models/dense_heads/atss_head.py index 6fa474e..b2251af 100644 --- a/mmdet2trt/models/dense_heads/atss_head.py +++ b/mmdet2trt/models/dense_heads/atss_head.py @@ -1,12 +1,11 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.models.builder import register_wraper +import torch +from mmdet2trt.models.builder import register_wrapper from .anchor_head import AnchorHeadWraper -@register_wraper('mmdet.models.dense_heads.ATSSHead') +@register_wrapper('mmdet.models.dense_heads.ATSSHead') class ATSSHeadWraper(AnchorHeadWraper): def __init__(self, module): @@ -17,7 +16,7 @@ def forward(self, feat, x): cls_scores, bbox_preds, centernesses = module(feat) - mlvl_anchors = self.anchor_generator( + mlvl_anchors = self.prior_generator( cls_scores, device=cls_scores[0].device) mlvl_scores = [] diff --git a/mmdet2trt/models/dense_heads/cascade_rpn_head.py b/mmdet2trt/models/dense_heads/cascade_rpn_head.py index 96d7e59..87f0da7 100644 --- a/mmdet2trt/models/dense_heads/cascade_rpn_head.py +++ b/mmdet2trt/models/dense_heads/cascade_rpn_head.py @@ -1,13 +1,12 @@ +import mmdet2trt.ops.util_ops as mm2trt_util import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper from torch import nn -import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.models.builder import build_wraper, register_wraper - from .rpn_head import RPNHeadWraper -@register_wraper('mmdet.models.dense_heads.StageCascadeRPNHead') +@register_wrapper('mmdet.models.dense_heads.StageCascadeRPNHead') class StageCascadeRPNHeadWraper(RPNHeadWraper): def __init__(self, module): @@ -26,7 +25,7 @@ def forward(self, x, offset_list): return self.module(x, offset_list) def get_anchors(self, featmaps, device='cuda'): - return self.anchor_generator(featmaps, device=device) + return self.prior_generator(featmaps, device=device) def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes): @@ -137,13 +136,13 @@ def get_bboxes(self, mlvl_anchors, cls_scores, bbox_preds, img_metas, return proposals -@register_wraper('mmdet.models.dense_heads.CascadeRPNHead') +@register_wrapper('mmdet.models.dense_heads.CascadeRPNHead') class CascadeRPNHeadWraper(nn.Module): def __init__(self, module): super(CascadeRPNHeadWraper, self).__init__() self.module = module - self.stages = [build_wraper(stage) for stage in self.module.stages] + self.stages = [build_wrapper(stage) for stage in self.module.stages] self.test_cfg = module.test_cfg def forward(self, feat, x): diff --git a/mmdet2trt/models/dense_heads/centripetal_head.py b/mmdet2trt/models/dense_heads/centripetal_head.py index f6eb92a..cab9b1d 100644 --- a/mmdet2trt/models/dense_heads/centripetal_head.py +++ b/mmdet2trt/models/dense_heads/centripetal_head.py @@ -1,11 +1,10 @@ import torch - -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from .corner_head import CornerHeadWraper -@register_wraper('mmdet.models.dense_heads.CentripetalHead') +@register_wrapper('mmdet.models.dense_heads.CentripetalHead') class CentripetalHeadWraper(CornerHeadWraper): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/corner_head.py b/mmdet2trt/models/dense_heads/corner_head.py index 5c527ee..514a952 100644 --- a/mmdet2trt/models/dense_heads/corner_head.py +++ b/mmdet2trt/models/dense_heads/corner_head.py @@ -1,12 +1,11 @@ import torch import torch.nn.functional as F -from torch import nn - from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper +from torch import nn -@register_wraper('mmdet.models.CornerHead') +@register_wrapper('mmdet.models.CornerHead') class CornerHeadWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/detr_head.py b/mmdet2trt/models/dense_heads/detr_head.py index 903044d..4887f8b 100644 --- a/mmdet2trt/models/dense_heads/detr_head.py +++ b/mmdet2trt/models/dense_heads/detr_head.py @@ -1,19 +1,18 @@ import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from mmdet2trt.structures.bbox.transforms import batched_bbox_cxcywh_to_xyxy from torch import nn from torch.nn import functional as F -from mmdet2trt.core.bbox.transforms import batched_bbox_cxcywh_to_xyxy -from mmdet2trt.models.builder import build_wraper, register_wraper - -@register_wraper('mmdet.models.dense_heads.DETRHead') +@register_wrapper('mmdet.models.dense_heads.DETRHead') class DETRHeadWraper(nn.Module): def __init__(self, module): super(DETRHeadWraper, self).__init__() self.module = module self.test_cfg = module.test_cfg - self.positional_encoding = build_wraper(module.positional_encoding) + self.positional_encoding = build_wrapper(module.positional_encoding) def module_forward(self, feats, x): module = self.module diff --git a/mmdet2trt/models/dense_heads/fcos_head.py b/mmdet2trt/models/dense_heads/fcos_head.py index 8e15aa8..9461e94 100644 --- a/mmdet2trt/models/dense_heads/fcos_head.py +++ b/mmdet2trt/models/dense_heads/fcos_head.py @@ -1,13 +1,12 @@ -import torch - import mmdet2trt.core.post_processing.batched_nms as batched_nms import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.core.bbox import batched_distance2bbox -from mmdet2trt.models.builder import register_wraper +import torch +from mmdet2trt.models.builder import register_wrapper from mmdet2trt.models.dense_heads.anchor_free_head import AnchorFreeHeadWraper +from mmdet2trt.structures.bbox.transforms import batched_distance2bbox -@register_wraper('mmdet.models.FCOSHead') +@register_wrapper('mmdet.models.FCOSHead') class FCOSHeadWraper(AnchorFreeHeadWraper): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/fovea_head.py b/mmdet2trt/models/dense_heads/fovea_head.py index 86d00a8..37a926e 100644 --- a/mmdet2trt/models/dense_heads/fovea_head.py +++ b/mmdet2trt/models/dense_heads/fovea_head.py @@ -1,11 +1,10 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.models.builder import register_wraper +import torch +from mmdet2trt.models.builder import register_wrapper from mmdet2trt.models.dense_heads.anchor_free_head import AnchorFreeHeadWraper -@register_wraper('mmdet.models.FoveaHead') +@register_wrapper('mmdet.models.FoveaHead') class FoveaHeadWraper(AnchorFreeHeadWraper): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/ga_rpn_head.py b/mmdet2trt/models/dense_heads/ga_rpn_head.py index e6a6399..0490f05 100644 --- a/mmdet2trt/models/dense_heads/ga_rpn_head.py +++ b/mmdet2trt/models/dense_heads/ga_rpn_head.py @@ -1,13 +1,12 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from .guided_anchor_head import GuidedAnchorHeadWraper -@register_wraper('mmdet.models.GARPNHead') +@register_wrapper('mmdet.models.GARPNHead') class GARPNHeadWraper(GuidedAnchorHeadWraper): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/gfl_head.py b/mmdet2trt/models/dense_heads/gfl_head.py index a9e2f27..8a9fedf 100644 --- a/mmdet2trt/models/dense_heads/gfl_head.py +++ b/mmdet2trt/models/dense_heads/gfl_head.py @@ -1,15 +1,14 @@ +import mmdet2trt.ops.util_ops as mm2trt_util import torch import torch.nn.functional as F - -import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.core.bbox.transforms import batched_distance2bbox from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper +from mmdet2trt.structures.bbox.transforms import batched_distance2bbox from .anchor_head import AnchorHeadWraper -@register_wraper('mmdet.models.GFLHead') +@register_wrapper('mmdet.models.GFLHead') class GFLHeadWraper(AnchorHeadWraper): def __init__(self, module): @@ -49,7 +48,7 @@ def forward(self, feat, x): cls_scores, bbox_preds = module(feat) num_levels = len(cls_scores) - mlvl_anchors = self.anchor_generator( + mlvl_anchors = self.prior_generator( cls_scores, device=cls_scores[0].device) mlvl_scores = [] @@ -59,7 +58,7 @@ def forward(self, feat, x): rpn_cls_score = cls_scores[idx] rpn_bbox_pred = bbox_preds[idx] anchors = mlvl_anchors[idx] - stride = module.anchor_generator.strides[idx] + stride = module.prior_generator.strides[idx] scores = rpn_cls_score.permute(0, 2, 3, 1).reshape( rpn_cls_score.shape[0], -1, module.cls_out_channels).sigmoid() bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1) diff --git a/mmdet2trt/models/dense_heads/guided_anchor_head.py b/mmdet2trt/models/dense_heads/guided_anchor_head.py index d6177a7..68f39ac 100644 --- a/mmdet2trt/models/dense_heads/guided_anchor_head.py +++ b/mmdet2trt/models/dense_heads/guided_anchor_head.py @@ -1,12 +1,11 @@ -import torch -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.GARetinaHead') +@register_wrapper('mmdet.models.GARetinaHead') class GuidedAnchorHeadWraper(nn.Module): def __init__(self, module): @@ -14,10 +13,10 @@ def __init__(self, module): self.module = module self.loc_filter_thr = module.loc_filter_thr self.num_anchors = module.num_anchors - self.square_anchor_generator = build_wraper( + self.square_anchor_generator = build_wrapper( self.module.square_anchor_generator) - self.anchor_coder = build_wraper(self.module.anchor_coder) - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.anchor_coder = build_wrapper(self.module.anchor_coder) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.test_cfg = module.test_cfg self.num_classes = self.module.num_classes diff --git a/mmdet2trt/models/dense_heads/paa_head.py b/mmdet2trt/models/dense_heads/paa_head.py index f890ec6..8ff2c82 100644 --- a/mmdet2trt/models/dense_heads/paa_head.py +++ b/mmdet2trt/models/dense_heads/paa_head.py @@ -1,13 +1,12 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.core.bbox.iou_calculators import bbox_overlaps_batched -from mmdet2trt.models.builder import register_wraper +import torch +from mmdet2trt.models.builder import register_wrapper +from mmdet2trt.structures.bbox import bbox_overlaps_batched from .anchor_head import AnchorHeadWraper -@register_wraper('mmdet.models.dense_heads.paa_head.PAAHead') +@register_wrapper('mmdet.models.dense_heads.paa_head.PAAHead') class PPAHeadWraper(AnchorHeadWraper): def __init__(self, module): @@ -18,7 +17,7 @@ def forward(self, feat, x): cls_scores, bbox_preds, iou_preds = module(feat) - mlvl_anchors = self.anchor_generator( + mlvl_anchors = self.prior_generator( cls_scores, device=cls_scores[0].device) mlvl_scores = [] diff --git a/mmdet2trt/models/dense_heads/reppoints_head.py b/mmdet2trt/models/dense_heads/reppoints_head.py index 268c86f..31eb4e3 100644 --- a/mmdet2trt/models/dense_heads/reppoints_head.py +++ b/mmdet2trt/models/dense_heads/reppoints_head.py @@ -1,11 +1,10 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.models.builder import build_wraper, register_wraper +import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper from mmdet2trt.models.dense_heads.anchor_free_head import AnchorFreeHeadWraper -@register_wraper('mmdet.models.RepPointsHead') +@register_wrapper('mmdet.models.RepPointsHead') class RepPointsHeadWraper(AnchorFreeHeadWraper): def __init__(self, module): @@ -13,15 +12,15 @@ def __init__(self, module): if hasattr(self.module, 'prior_generator'): # mmdet 2.18 - self.prior_generator = build_wraper(self.module.prior_generator) + self.prior_generator = build_wrapper(self.module.prior_generator) elif hasattr(self.module, 'point_generators'): # mmdet 2.10 self.point_generators = [ - build_wraper(generator) + build_wrapper(generator) for generator in self.module.point_generators ] else: - self.point_generator = build_wraper(self.module.point_generator) + self.point_generator = build_wrapper(self.module.point_generator) def forward(self, feat, x): img_shape = x.shape[2:] diff --git a/mmdet2trt/models/dense_heads/rpn_head.py b/mmdet2trt/models/dense_heads/rpn_head.py index b400d21..c401c42 100644 --- a/mmdet2trt/models/dense_heads/rpn_head.py +++ b/mmdet2trt/models/dense_heads/rpn_head.py @@ -1,19 +1,18 @@ -import torch -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.RPNHead') +@register_wrapper('mmdet.models.RPNHead') class RPNHeadWraper(nn.Module): def __init__(self, module): super(RPNHeadWraper, self).__init__() self.module = module - self.anchor_generator = build_wraper(self.module.anchor_generator) - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.prior_generator = build_wrapper(self.module.prior_generator) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.test_cfg = module.test_cfg if 'nms' in self.test_cfg: @@ -30,7 +29,7 @@ def forward(self, feat, x): cls_scores, bbox_preds = module(feat) - mlvl_anchors = self.anchor_generator( + mlvl_anchors = self.prior_generator( cls_scores, device=cls_scores[0].device) mlvl_scores = [] diff --git a/mmdet2trt/models/dense_heads/sabl_retina_head.py b/mmdet2trt/models/dense_heads/sabl_retina_head.py index 129c31b..6424c8c 100644 --- a/mmdet2trt/models/dense_heads/sabl_retina_head.py +++ b/mmdet2trt/models/dense_heads/sabl_retina_head.py @@ -1,20 +1,19 @@ -import torch -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.dense_heads.SABLRetinaHead') +@register_wrapper('mmdet.models.dense_heads.SABLRetinaHead') class SABLRetinaHeadWraper(nn.Module): def __init__(self, module): super(SABLRetinaHeadWraper, self).__init__() self.module = module - self.square_anchor_generator = build_wraper( + self.square_anchor_generator = build_wrapper( self.module.square_anchor_generator) - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.test_cfg = module.test_cfg self.num_classes = self.module.num_classes diff --git a/mmdet2trt/models/dense_heads/vfnet_head.py b/mmdet2trt/models/dense_heads/vfnet_head.py index de766b0..6c0b302 100644 --- a/mmdet2trt/models/dense_heads/vfnet_head.py +++ b/mmdet2trt/models/dense_heads/vfnet_head.py @@ -1,13 +1,12 @@ -import torch - import mmdet2trt.core.post_processing.batched_nms as batched_nms import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.core.bbox import batched_distance2bbox -from mmdet2trt.models.builder import register_wraper +import torch +from mmdet2trt.models.builder import register_wrapper from mmdet2trt.models.dense_heads.anchor_free_head import AnchorFreeHeadWraper +from mmdet2trt.structures.bbox import batched_distance2bbox -@register_wraper('mmdet.models.VFNetHead') +@register_wrapper('mmdet.models.VFNetHead') class VFNetHeadWraper(AnchorFreeHeadWraper): def __init__(self, module): diff --git a/mmdet2trt/models/dense_heads/yolo_head.py b/mmdet2trt/models/dense_heads/yolo_head.py index 7c1c9d8..634dd01 100644 --- a/mmdet2trt/models/dense_heads/yolo_head.py +++ b/mmdet2trt/models/dense_heads/yolo_head.py @@ -1,19 +1,18 @@ -import torch -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.dense_heads.YOLOV3Head') +@register_wrapper('mmdet.models.dense_heads.YOLOV3Head') class YOLOV3HeadWraper(nn.Module): def __init__(self, module): super(YOLOV3HeadWraper, self).__init__() self.module = module - self.anchor_generator = build_wraper(self.module.anchor_generator) - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.prior_generator = build_wrapper(self.module.prior_generator) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.featmap_strides = module.featmap_strides self.num_attrib = module.num_attrib self.num_levels = module.num_levels @@ -35,7 +34,7 @@ def forward(self, feats, x): pred_maps_list = module(feats)[0] - multi_lvl_anchors = self.anchor_generator( + multi_lvl_anchors = self.prior_generator( pred_maps_list, device=pred_maps_list[0].device) multi_lvl_bboxes = [] diff --git a/mmdet2trt/models/dense_heads/yolox_head.py b/mmdet2trt/models/dense_heads/yolox_head.py index 93b5c9c..fd67d59 100644 --- a/mmdet2trt/models/dense_heads/yolox_head.py +++ b/mmdet2trt/models/dense_heads/yolox_head.py @@ -1,17 +1,16 @@ import torch -from torch import nn - from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper('mmdet.models.dense_heads.YOLOXHead') +@register_wrapper('mmdet.models.dense_heads.YOLOXHead') class YOLOXHeadWraper(nn.Module): def __init__(self, module): super(YOLOXHeadWraper, self).__init__() self.module = module - self.prior_generator = build_wraper(self.module.prior_generator) + self.prior_generator = build_wrapper(self.module.prior_generator) self.cls_out_channels = self.module.cls_out_channels iou_thr = 0.7 if 'iou_thr' in module.test_cfg.nms: diff --git a/mmdet2trt/models/detectors/single_stage.py b/mmdet2trt/models/detectors/single_stage.py index 194207c..d07fd6d 100644 --- a/mmdet2trt/models/detectors/single_stage.py +++ b/mmdet2trt/models/detectors/single_stage.py @@ -1,24 +1,23 @@ -from torch import nn - from mmdet2trt.models.backbones import BaseBackboneWraper -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper from mmdet2trt.models.necks import BaseNeckWraper +from torch import nn -@register_wraper('mmdet.models.YOLOX') -@register_wraper('mmdet.models.GFL') -@register_wraper('mmdet.models.CornerNet') -@register_wraper('mmdet.models.PAA') -@register_wraper('mmdet.models.YOLOV3') -@register_wraper('mmdet.models.FSAF') -@register_wraper('mmdet.models.ATSS') -@register_wraper('mmdet.models.RepPointsDetector') -@register_wraper('mmdet.models.FOVEA') -@register_wraper('mmdet.models.FCOS') -@register_wraper('mmdet.models.RetinaNet') -@register_wraper('mmdet.models.SingleStageDetector') -@register_wraper('mmdet.models.VFNet') -@register_wraper('mmdet.models.DETR') +@register_wrapper('mmdet.models.YOLOX') +@register_wrapper('mmdet.models.GFL') +@register_wrapper('mmdet.models.CornerNet') +@register_wrapper('mmdet.models.PAA') +@register_wrapper('mmdet.models.YOLOV3') +@register_wrapper('mmdet.models.FSAF') +@register_wrapper('mmdet.models.ATSS') +@register_wrapper('mmdet.models.RepPointsDetector') +@register_wrapper('mmdet.models.FOVEA') +@register_wrapper('mmdet.models.FCOS') +@register_wrapper('mmdet.models.RetinaNet') +@register_wrapper('mmdet.models.SingleStageDetector') +@register_wrapper('mmdet.models.VFNet') +@register_wrapper('mmdet.models.DETR') class SingleStageDetectorWraper(nn.Module): def __init__(self, model, wrap_config={}): @@ -26,26 +25,24 @@ def __init__(self, model, wrap_config={}): self.model = model mmdet_backbone = self.model.backbone - self.backbone_wraper = build_wraper(mmdet_backbone, BaseBackboneWraper) + self.backbone = build_wrapper(mmdet_backbone, BaseBackboneWraper) if self.model.with_neck: mmdet_neck = self.model.neck - self.neck_wraper = build_wraper(mmdet_neck, BaseNeckWraper) + self.neck = build_wrapper(mmdet_neck, BaseNeckWraper) mmdet_bbox_head = self.model.bbox_head - self.bbox_head_wraper = build_wraper(mmdet_bbox_head) + self.bbox_head = build_wrapper(mmdet_bbox_head) def extract_feat(self, img): - x = self.backbone_wraper(img) + x = self.backbone(img) if self.model.with_neck: - x = self.neck_wraper(x) + x = self.neck(x) return x def forward(self, x): - bbox_head = self.bbox_head_wraper - # backbone feat = self.extract_feat(x) - result = bbox_head(feat, x) + result = self.bbox_head(feat, x) return result diff --git a/mmdet2trt/models/detectors/two_stage.py b/mmdet2trt/models/detectors/two_stage.py index 660a591..c0ef194 100644 --- a/mmdet2trt/models/detectors/two_stage.py +++ b/mmdet2trt/models/detectors/two_stage.py @@ -1,19 +1,18 @@ -from torch import nn - from mmdet2trt.models.backbones import BaseBackboneWraper -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper from mmdet2trt.models.dense_heads import RPNHeadWraper from mmdet2trt.models.necks import BaseNeckWraper from mmdet2trt.models.roi_heads import StandardRoIHeadWraper +from torch import nn -@register_wraper('mmdet.models.MaskScoringRCNN') -@register_wraper('mmdet.models.GridRCNN') -@register_wraper('mmdet.models.HybridTaskCascade') -@register_wraper('mmdet.models.MaskRCNN') -@register_wraper('mmdet.models.CascadeRCNN') -@register_wraper('mmdet.models.FasterRCNN') -@register_wraper('mmdet.models.TwoStageDetector') +@register_wrapper('mmdet.models.MaskScoringRCNN') +@register_wrapper('mmdet.models.GridRCNN') +@register_wrapper('mmdet.models.HybridTaskCascade') +@register_wrapper('mmdet.models.MaskRCNN') +@register_wrapper('mmdet.models.CascadeRCNN') +@register_wrapper('mmdet.models.FasterRCNN') +@register_wrapper('mmdet.models.TwoStageDetector') class TwoStageDetectorWraper(nn.Module): def __init__(self, model, wrap_config={}): @@ -21,17 +20,18 @@ def __init__(self, model, wrap_config={}): self.model = model mmdet_backbone = self.model.backbone - self.backbone_wraper = build_wraper(mmdet_backbone, BaseBackboneWraper) + self.backbone_wraper = build_wrapper(mmdet_backbone, + BaseBackboneWraper) if self.model.with_neck: mmdet_neck = self.model.neck - self.neck_wraper = build_wraper(mmdet_neck, BaseNeckWraper) + self.neck_wraper = build_wrapper(mmdet_neck, BaseNeckWraper) mmdet_rpn_head = self.model.rpn_head - self.rpn_head_wraper = build_wraper(mmdet_rpn_head, RPNHeadWraper) + self.rpn_head_wraper = build_wrapper(mmdet_rpn_head, RPNHeadWraper) mmdet_roi_head = self.model.roi_head - self.roi_head_wraper = build_wraper( + self.roi_head_wraper = build_wrapper( mmdet_roi_head, StandardRoIHeadWraper, wrap_config=wrap_config) def extract_feat(self, img): diff --git a/mmdet2trt/models/utils/__init__.py b/mmdet2trt/models/layers/__init__.py similarity index 100% rename from mmdet2trt/models/utils/__init__.py rename to mmdet2trt/models/layers/__init__.py diff --git a/mmdet2trt/models/utils/position_encoding.py b/mmdet2trt/models/layers/position_encoding.py similarity index 92% rename from mmdet2trt/models/utils/position_encoding.py rename to mmdet2trt/models/layers/position_encoding.py index 11b486b..3c2e253 100644 --- a/mmdet2trt/models/utils/position_encoding.py +++ b/mmdet2trt/models/layers/position_encoding.py @@ -1,10 +1,9 @@ import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.utils.SinePositionalEncoding') +@register_wrapper('mmdet.models.layers.SinePositionalEncoding') class SinePositionalEncodingWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/necks/base_neck.py b/mmdet2trt/models/necks/base_neck.py index 5db3b1f..310750f 100644 --- a/mmdet2trt/models/necks/base_neck.py +++ b/mmdet2trt/models/necks/base_neck.py @@ -1,16 +1,15 @@ import torch.nn as nn +from mmdet2trt.models.builder import register_wrapper -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.necks.YOLOXPAFPN') -@register_wraper('mmdet.models.necks.FPN') -@register_wraper('mmdet.models.necks.BFP') -@register_wraper('mmdet.models.necks.FPN_CARAFE') -@register_wraper('mmdet.models.necks.NASFPN') -@register_wraper('mmdet.models.necks.RFP') -@register_wraper('mmdet.models.necks.YOLOV3Neck') -@register_wraper('mmdet.models.necks.SSDNeck') +@register_wrapper('mmdet.models.necks.YOLOXPAFPN') +@register_wrapper('mmdet.models.necks.FPN') +@register_wrapper('mmdet.models.necks.BFP') +@register_wrapper('mmdet.models.necks.FPN_CARAFE') +@register_wrapper('mmdet.models.necks.NASFPN') +@register_wrapper('mmdet.models.necks.RFP') +@register_wrapper('mmdet.models.necks.YOLOV3Neck') +@register_wrapper('mmdet.models.necks.SSDNeck') class BaseNeckWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/necks/hrfpn.py b/mmdet2trt/models/necks/hrfpn.py index 9130a38..b884838 100644 --- a/mmdet2trt/models/necks/hrfpn.py +++ b/mmdet2trt/models/necks/hrfpn.py @@ -1,7 +1,6 @@ import torch.nn as nn import torch.nn.functional as F - -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper def pooling_wrap(pooling): @@ -20,7 +19,7 @@ def pool(*args, **kwargs): return None -@register_wraper('mmdet.models.necks.HRFPN') +@register_wrapper('mmdet.models.necks.HRFPN') class HRFPNWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/roi_heads/bbox_heads/bbox_head.py b/mmdet2trt/models/roi_heads/bbox_heads/bbox_head.py index 71d28ba..908a7a9 100644 --- a/mmdet2trt/models/roi_heads/bbox_heads/bbox_head.py +++ b/mmdet2trt/models/roi_heads/bbox_heads/bbox_head.py @@ -1,16 +1,15 @@ import torch import torch.nn.functional as F -from torch import nn - from mmdet2trt.core.post_processing.batched_nms import BatchedNMS -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from torch import nn -@register_wraper( +@register_wrapper( 'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.ConvFCBBoxHead') -@register_wraper( +@register_wrapper( 'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.Shared2FCBBoxHead') -@register_wraper( +@register_wrapper( 'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.Shared4Conv1FCBBoxHead' ) class BBoxHeadWraper(nn.Module): @@ -19,7 +18,7 @@ def __init__(self, module, test_cfg): super(BBoxHeadWraper, self).__init__() self.module = module - self.bbox_coder = build_wraper(self.module.bbox_coder) + self.bbox_coder = build_wrapper(self.module.bbox_coder) self.test_cfg = test_cfg self.num_classes = module.num_classes self.rcnn_nms = BatchedNMS( diff --git a/mmdet2trt/models/roi_heads/bbox_heads/double_bbox_head.py b/mmdet2trt/models/roi_heads/bbox_heads/double_bbox_head.py index 0b0befd..1c0e91c 100644 --- a/mmdet2trt/models/roi_heads/bbox_heads/double_bbox_head.py +++ b/mmdet2trt/models/roi_heads/bbox_heads/double_bbox_head.py @@ -1,9 +1,9 @@ -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from .bbox_head import BBoxHeadWraper -@register_wraper( +@register_wrapper( 'mmdet.models.roi_heads.bbox_heads.double_bbox_head.DoubleConvFCBBoxHead') class DoubleConvFCBBoxHeadWraper(BBoxHeadWraper): diff --git a/mmdet2trt/models/roi_heads/bbox_heads/sabl_head.py b/mmdet2trt/models/roi_heads/bbox_heads/sabl_head.py index 9f91e13..e304c0c 100644 --- a/mmdet2trt/models/roi_heads/bbox_heads/sabl_head.py +++ b/mmdet2trt/models/roi_heads/bbox_heads/sabl_head.py @@ -1,12 +1,11 @@ import torch import torch.nn.functional as F - -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from .bbox_head import BBoxHeadWraper -@register_wraper('mmdet.models.roi_heads.bbox_heads.sabl_head.SABLHead') +@register_wrapper('mmdet.models.roi_heads.bbox_heads.sabl_head.SABLHead') class SABLHeadWraper(BBoxHeadWraper): def __init__(self, module, test_cfg): diff --git a/mmdet2trt/models/roi_heads/cascade_roi_head.py b/mmdet2trt/models/roi_heads/cascade_roi_head.py index 00c0bad..013d73e 100644 --- a/mmdet2trt/models/roi_heads/cascade_roi_head.py +++ b/mmdet2trt/models/roi_heads/cascade_roi_head.py @@ -1,13 +1,12 @@ -import torch -from mmdet.core.bbox.coder.delta_xywh_bbox_coder import delta2bbox -from torch import nn - import mmdet2trt.ops.util_ops as mm2trt_util +import torch from mmdet2trt.core.post_processing import merge_aug_masks -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import delta2bbox +from torch import nn -@register_wraper('mmdet.models.roi_heads.CascadeRoIHead') +@register_wrapper('mmdet.models.roi_heads.CascadeRoIHead') class CascadeRoIHeadWraper(nn.Module): def __init__(self, module, wrap_config): @@ -16,10 +15,10 @@ def __init__(self, module, wrap_config): self.wrap_config = wrap_config self.bbox_roi_extractor = [ - build_wraper(extractor) for extractor in module.bbox_roi_extractor + build_wrapper(extractor) for extractor in module.bbox_roi_extractor ] self.bbox_head = [ - build_wraper(bb_head, test_cfg=module.test_cfg) + build_wrapper(bb_head, test_cfg=module.test_cfg) for bb_head in module.bbox_head ] if module.with_shared_head: @@ -40,10 +39,11 @@ def __init__(self, module, wrap_config): def init_mask_head(self, mask_roi_extractor, mask_head): self.mask_roi_extractor = [ - build_wraper(mre) for mre in mask_roi_extractor + build_wrapper(mre) for mre in mask_roi_extractor ] self.mask_head = [ - build_wraper(mh, test_cfg=self.module.test_cfg) for mh in mask_head + build_wrapper(mh, test_cfg=self.module.test_cfg) + for mh in mask_head ] def _bbox_forward(self, stage, x, rois): diff --git a/mmdet2trt/models/roi_heads/double_roi_head.py b/mmdet2trt/models/roi_heads/double_roi_head.py index 54b4a09..f8f13e1 100644 --- a/mmdet2trt/models/roi_heads/double_roi_head.py +++ b/mmdet2trt/models/roi_heads/double_roi_head.py @@ -1,9 +1,9 @@ -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from .standard_roi_head import StandardRoIHeadWraper -@register_wraper('mmdet.models.roi_heads.double_roi_head.DoubleHeadRoIHead') +@register_wrapper('mmdet.models.roi_heads.double_roi_head.DoubleHeadRoIHead') class DoubleHeadRoIHeadWraper(StandardRoIHeadWraper): def __init__(self, module, wrap_config): diff --git a/mmdet2trt/models/roi_heads/grid_roi_head.py b/mmdet2trt/models/roi_heads/grid_roi_head.py index 3ea0fe1..85970af 100644 --- a/mmdet2trt/models/roi_heads/grid_roi_head.py +++ b/mmdet2trt/models/roi_heads/grid_roi_head.py @@ -1,20 +1,19 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.core.bbox.transforms import bbox2roi -from mmdet2trt.models.builder import build_wraper, register_wraper +import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from mmdet2trt.structures.bbox.transforms import bbox2roi from .standard_roi_head import StandardRoIHeadWraper -@register_wraper('mmdet.models.roi_heads.grid_roi_head.GridRoIHead') +@register_wrapper('mmdet.models.roi_heads.grid_roi_head.GridRoIHead') class GridRoIHeadWraper(StandardRoIHeadWraper): def __init__(self, module, wrap_config): super(GridRoIHeadWraper, self).__init__(module, wrap_config) - self.grid_roi_extractor = build_wraper(module.grid_roi_extractor) - self.grid_head = build_wraper( + self.grid_roi_extractor = build_wrapper(module.grid_roi_extractor) + self.grid_head = build_wrapper( module.grid_head, test_cfg=module.test_cfg) def forward(self, feat, proposals, img_shape): diff --git a/mmdet2trt/models/roi_heads/htc_roi_head.py b/mmdet2trt/models/roi_heads/htc_roi_head.py index bf862a3..d7a985e 100644 --- a/mmdet2trt/models/roi_heads/htc_roi_head.py +++ b/mmdet2trt/models/roi_heads/htc_roi_head.py @@ -1,15 +1,14 @@ +import mmdet2trt.ops.util_ops as mm2trt_util import torch import torch.nn.functional as F -from mmdet.core.bbox.coder.delta_xywh_bbox_coder import delta2bbox - -import mmdet2trt.ops.util_ops as mm2trt_util from mmdet2trt.core.post_processing import merge_aug_masks -from mmdet2trt.models.builder import build_wraper, register_wraper +from mmdet2trt.models.builder import build_wrapper, register_wrapper +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import delta2bbox from .cascade_roi_head import CascadeRoIHeadWraper -@register_wraper('mmdet.models.roi_heads.HybridTaskCascadeRoIHead') +@register_wrapper('mmdet.models.roi_heads.HybridTaskCascadeRoIHead') class HybridTaskCascadeRoIHeadWraper(CascadeRoIHeadWraper): def __init__(self, module, wrap_config): @@ -19,7 +18,7 @@ def __init__(self, module, wrap_config): module = self.module self.semantic_head = None if module.semantic_head is not None: - self.semantic_roi_extractor = build_wraper( + self.semantic_roi_extractor = build_wrapper( module.semantic_roi_extractor) self.semantic_head = module.semantic_head diff --git a/mmdet2trt/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet2trt/models/roi_heads/mask_heads/fcn_mask_head.py index 752d223..f99918f 100644 --- a/mmdet2trt/models/roi_heads/mask_heads/fcn_mask_head.py +++ b/mmdet2trt/models/roi_heads/mask_heads/fcn_mask_head.py @@ -1,11 +1,11 @@ import torch import torch.nn.functional as F +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.roi_heads.mask_heads.fcn_mask_head.FCNMaskHead') +@register_wrapper('mmdet.models.roi_heads.mask_heads.fcn_mask_head.FCNMaskHead' + ) class FCNMaskHeadWraper(nn.Module): def __init__(self, module, test_cfg): diff --git a/mmdet2trt/models/roi_heads/mask_heads/grid_head.py b/mmdet2trt/models/roi_heads/mask_heads/grid_head.py index 90da07e..574c6eb 100644 --- a/mmdet2trt/models/roi_heads/mask_heads/grid_head.py +++ b/mmdet2trt/models/roi_heads/mask_heads/grid_head.py @@ -1,10 +1,9 @@ import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.roi_heads.mask_heads.grid_head.GridHead') +@register_wrapper('mmdet.models.roi_heads.mask_heads.grid_head.GridHead') class GridHeadWraper(nn.Module): def __init__(self, module, test_cfg): diff --git a/mmdet2trt/models/roi_heads/mask_heads/htc_mask_head.py b/mmdet2trt/models/roi_heads/mask_heads/htc_mask_head.py index e543bce..c39f9d2 100644 --- a/mmdet2trt/models/roi_heads/mask_heads/htc_mask_head.py +++ b/mmdet2trt/models/roi_heads/mask_heads/htc_mask_head.py @@ -1,9 +1,9 @@ +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.roi_heads.mask_heads.htc_mask_head.HTCMaskHead') +@register_wrapper('mmdet.models.roi_heads.mask_heads.htc_mask_head.HTCMaskHead' + ) class HTCMaskHeadWraper(nn.Module): def __init__(self, module, test_cfg): diff --git a/mmdet2trt/models/roi_heads/roi_extractors/generic_roi_extractor.py b/mmdet2trt/models/roi_heads/roi_extractors/generic_roi_extractor.py index 8d976e1..962d07b 100644 --- a/mmdet2trt/models/roi_heads/roi_extractors/generic_roi_extractor.py +++ b/mmdet2trt/models/roi_heads/roi_extractors/generic_roi_extractor.py @@ -1,10 +1,9 @@ +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.models.roi_heads.roi_extractors' - '.generic_roi_extractor.GenericRoIExtractor') +@register_wrapper('mmdet.models.roi_heads.roi_extractors' + '.generic_roi_extractor.GenericRoIExtractor') class GenericRoIExtractorWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/roi_heads/roi_extractors/pooling_layers/deform_roi_pool_extractor.py b/mmdet2trt/models/roi_heads/roi_extractors/pooling_layers/deform_roi_pool_extractor.py index 0f12630..26d2591 100644 --- a/mmdet2trt/models/roi_heads/roi_extractors/pooling_layers/deform_roi_pool_extractor.py +++ b/mmdet2trt/models/roi_heads/roi_extractors/pooling_layers/deform_roi_pool_extractor.py @@ -1,13 +1,13 @@ -import mmcv.ops import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper from torch import nn -from mmdet2trt.models.builder import build_wraper, register_wraper +import mmcv.ops deformable_roi_pool_wrap = mmcv.ops.deform_roi_pool -@register_wraper('mmcv.ops.DeformRoIPoolPack') +@register_wrapper('mmcv.ops.DeformRoIPoolPack') class DeformRoIPoolPackWraper(nn.Module): def __init__(self, module): @@ -32,7 +32,7 @@ def forward(self, input, rois): self.module.gamma) -@register_wraper('mmcv.ops.ModulatedDeformRoIPoolPack') +@register_wrapper('mmcv.ops.ModulatedDeformRoIPoolPack') class ModulatedDeformRoIPoolPackWraper(nn.Module): def __init__(self, module): @@ -67,7 +67,7 @@ def __init__(self, module): self.module = module self.roi_layers = [ - build_wraper(layer) for layer in self.module.roi_layers + build_wrapper(layer) for layer in self.module.roi_layers ] self.featmap_strides = self.module.featmap_strides self.finest_scale = self.module.finest_scale diff --git a/mmdet2trt/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet2trt/models/roi_heads/roi_extractors/single_level_roi_extractor.py index c53528a..60decf6 100644 --- a/mmdet2trt/models/roi_heads/roi_extractors/single_level_roi_extractor.py +++ b/mmdet2trt/models/roi_heads/roi_extractors/single_level_roi_extractor.py @@ -1,12 +1,11 @@ -from torch import nn - -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from mmdet2trt.models.roi_heads.roi_extractors.pooling_layers import \ build_roi_extractor +from torch import nn -@register_wraper('mmdet.models.roi_heads.roi_extractors' - '.single_level_roi_extractor.SingleRoIExtractor') +@register_wrapper('mmdet.models.roi_heads.roi_extractors' + '.single_level_roi_extractor.SingleRoIExtractor') class SingleRoIExtractorWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/models/roi_heads/standard_roi_head.py b/mmdet2trt/models/roi_heads/standard_roi_head.py index bfe3baf..f9eab15 100644 --- a/mmdet2trt/models/roi_heads/standard_roi_head.py +++ b/mmdet2trt/models/roi_heads/standard_roi_head.py @@ -1,14 +1,13 @@ +import mmdet2trt.ops.util_ops as mm2trt_util import torch +from mmdet2trt.models.builder import build_wrapper, register_wrapper from torch import nn -import mmdet2trt.ops.util_ops as mm2trt_util -from mmdet2trt.models.builder import build_wraper, register_wraper - -@register_wraper( +@register_wrapper( 'mmdet.models.roi_heads.mask_scoring_roi_head.MaskScoringRoIHead') -@register_wraper('mmdet.models.roi_heads.dynamic_roi_head.DynamicRoIHead') -@register_wraper('mmdet.models.roi_heads.standard_roi_head.StandardRoIHead') +@register_wrapper('mmdet.models.roi_heads.dynamic_roi_head.DynamicRoIHead') +@register_wrapper('mmdet.models.roi_heads.standard_roi_head.StandardRoIHead') class StandardRoIHeadWraper(nn.Module): def __init__(self, module, wrap_config={}): @@ -16,9 +15,9 @@ def __init__(self, module, wrap_config={}): self.module = module self.wrap_config = wrap_config - self.bbox_roi_extractor = build_wraper(module.bbox_roi_extractor) + self.bbox_roi_extractor = build_wrapper(module.bbox_roi_extractor) - self.bbox_head = build_wraper( + self.bbox_head = build_wrapper( module.bbox_head, test_cfg=module.test_cfg) if module.with_shared_head: self.shared_head = module.shared_head @@ -35,8 +34,9 @@ def __init__(self, module, wrap_config={}): self.test_cfg = module.test_cfg def init_mask_head(self, mask_roi_extractor, mask_head): - self.mask_roi_extractor = build_wraper(mask_roi_extractor) - self.mask_head = build_wraper(mask_head, test_cfg=self.module.test_cfg) + self.mask_roi_extractor = build_wrapper(mask_roi_extractor) + self.mask_head = build_wrapper( + mask_head, test_cfg=self.module.test_cfg) def _bbox_forward(self, x, rois): bbox_feats = self.bbox_roi_extractor( diff --git a/mmdet2trt/models/task_modules/__init__.py b/mmdet2trt/models/task_modules/__init__.py new file mode 100644 index 0000000..345edbb --- /dev/null +++ b/mmdet2trt/models/task_modules/__init__.py @@ -0,0 +1,2 @@ +from . import coders # noqa: F401,F403 +from . import prior_generators # noqa: F401,F403 diff --git a/mmdet2trt/core/bbox/coder/__init__.py b/mmdet2trt/models/task_modules/coders/__init__.py similarity index 100% rename from mmdet2trt/core/bbox/coder/__init__.py rename to mmdet2trt/models/task_modules/coders/__init__.py diff --git a/mmdet2trt/core/bbox/coder/bucketing_bbox_coder.py b/mmdet2trt/models/task_modules/coders/bucketing_bbox_coder.py similarity index 96% rename from mmdet2trt/core/bbox/coder/bucketing_bbox_coder.py rename to mmdet2trt/models/task_modules/coders/bucketing_bbox_coder.py index b9c0a4a..be6016b 100644 --- a/mmdet2trt/core/bbox/coder/bucketing_bbox_coder.py +++ b/mmdet2trt/models/task_modules/coders/bucketing_bbox_coder.py @@ -1,10 +1,9 @@ import numpy as np import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn from torch.nn import functional as F -from mmdet2trt.models.builder import register_wraper - from .transforms import bbox_rescale_batched @@ -70,7 +69,7 @@ def bucket2bbox_batched(proposals, return bboxes, loc_confidence -@register_wraper('mmdet.core.bbox.coder.BucketingBBoxCoder') +@register_wrapper('mmdet.models.task_modules.coders.BucketingBBoxCoder') class BucketingBBoxCoderWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet2trt/models/task_modules/coders/delta_xywh_bbox_coder.py similarity index 97% rename from mmdet2trt/core/bbox/coder/delta_xywh_bbox_coder.py rename to mmdet2trt/models/task_modules/coders/delta_xywh_bbox_coder.py index fb690a1..4c619df 100644 --- a/mmdet2trt/core/bbox/coder/delta_xywh_bbox_coder.py +++ b/mmdet2trt/models/task_modules/coders/delta_xywh_bbox_coder.py @@ -1,9 +1,8 @@ import numpy as np import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - def delta2bbox_custom_func(cls_scores, bbox_preds, @@ -83,7 +82,7 @@ def delta2bbox_batched(rois, return bboxes -@register_wraper('mmdet.core.bbox.coder.DeltaXYWHBBoxCoder') +@register_wrapper('mmdet.models.task_modules.coders.DeltaXYWHBBoxCoder') class DeltaXYWHBBoxCoderWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/core/bbox/coder/tblr_bbox_coder.py b/mmdet2trt/models/task_modules/coders/tblr_bbox_coder.py similarity index 95% rename from mmdet2trt/core/bbox/coder/tblr_bbox_coder.py rename to mmdet2trt/models/task_modules/coders/tblr_bbox_coder.py index 741f2c7..99d9533 100644 --- a/mmdet2trt/core/bbox/coder/tblr_bbox_coder.py +++ b/mmdet2trt/models/task_modules/coders/tblr_bbox_coder.py @@ -1,8 +1,7 @@ import torch -from torch import nn - -from mmdet2trt.models.builder import register_wraper +from mmdet2trt.models.builder import register_wrapper from mmdet2trt.ops import util_ops +from torch import nn def batched_blr2bboxes(priors, @@ -35,7 +34,7 @@ def batched_blr2bboxes(priors, return boxes -@register_wraper('mmdet.core.bbox.coder.TBLRBBoxCoder') +@register_wrapper('mmdet.models.task_modules.coders.TBLRBBoxCoder') class TBLRBBoxCoderWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/core/bbox/coder/transforms.py b/mmdet2trt/models/task_modules/coders/transforms.py similarity index 100% rename from mmdet2trt/core/bbox/coder/transforms.py rename to mmdet2trt/models/task_modules/coders/transforms.py diff --git a/mmdet2trt/core/bbox/coder/yolo_bbox_coder.py b/mmdet2trt/models/task_modules/coders/yolo_bbox_coder.py similarity index 90% rename from mmdet2trt/core/bbox/coder/yolo_bbox_coder.py rename to mmdet2trt/models/task_modules/coders/yolo_bbox_coder.py index 07df3f2..2352657 100644 --- a/mmdet2trt/core/bbox/coder/yolo_bbox_coder.py +++ b/mmdet2trt/models/task_modules/coders/yolo_bbox_coder.py @@ -1,8 +1,7 @@ import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - def yolodecoder_batched(bboxes, pred_bboxes, stride): x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5 @@ -23,7 +22,7 @@ def yolodecoder_batched(bboxes, pred_bboxes, stride): return decoded_bboxes -@register_wraper('mmdet.core.bbox.coder.YOLOBBoxCoder') +@register_wrapper('mmdet.models.task_modules.coders.YOLOBBoxCoder') class YOLOBBoxCoderWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/core/anchor/__init__.py b/mmdet2trt/models/task_modules/prior_generators/__init__.py similarity index 100% rename from mmdet2trt/core/anchor/__init__.py rename to mmdet2trt/models/task_modules/prior_generators/__init__.py diff --git a/mmdet2trt/core/anchor/anchor_generator.py b/mmdet2trt/models/task_modules/prior_generators/anchor_generator.py similarity index 88% rename from mmdet2trt/core/anchor/anchor_generator.py rename to mmdet2trt/models/task_modules/prior_generators/anchor_generator.py index be7c2f5..8413553 100644 --- a/mmdet2trt/core/anchor/anchor_generator.py +++ b/mmdet2trt/models/task_modules/prior_generators/anchor_generator.py @@ -1,7 +1,6 @@ +from mmdet2trt.models.builder import register_wrapper from torch import nn -from mmdet2trt.models.builder import register_wraper - class AnchorGeneratorSingle(nn.Module): @@ -32,8 +31,10 @@ def forward(self, x, stride=None, device='cuda'): device=device) -@register_wraper('mmdet.core.anchor.anchor_generator.YOLOAnchorGenerator') -@register_wraper('mmdet.core.AnchorGenerator') +@register_wrapper('mmdet.models.task_modules.prior_generators' + '.YOLOAnchorGenerator') +@register_wrapper('mmdet.models.task_modules.prior_generators' + '.AnchorGenerator') class AnchorGeneratorWraper(nn.Module): def __init__(self, module): @@ -65,7 +66,8 @@ def forward(self, feat_list, device='cuda'): return multi_level_anchors -@register_wraper('mmdet.core.anchor.anchor_generator.SSDAnchorGenerator') +@register_wrapper('mmdet.models.task_modules.prior_generators' + '.SSDAnchorGenerator') class SSDAnchorGeneratorWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/core/anchor/point_generator.py b/mmdet2trt/models/task_modules/prior_generators/point_generator.py similarity index 91% rename from mmdet2trt/core/anchor/point_generator.py rename to mmdet2trt/models/task_modules/prior_generators/point_generator.py index 78ea0d9..f3cbeb3 100644 --- a/mmdet2trt/core/anchor/point_generator.py +++ b/mmdet2trt/models/task_modules/prior_generators/point_generator.py @@ -1,11 +1,10 @@ +import mmdet2trt import torch +from mmdet2trt.models.builder import register_wrapper from torch import nn -import mmdet2trt -from mmdet2trt.models.builder import register_wraper - -@register_wraper('mmdet.core.anchor.point_generator.PointGenerator') +@register_wrapper('mmdet.models.task_modules.prior_generators.PointGenerator') class PointGeneratorWraper(nn.Module): def __init__(self, module): @@ -22,7 +21,8 @@ def forward(self, featmap, stride): return shifts -@register_wraper('mmdet.core.anchor.point_generator.MlvlPointGenerator') +@register_wrapper('mmdet.models.task_modules.prior_generators' + '.MlvlPointGenerator') class MlvlPointGeneratorWraper(nn.Module): def __init__(self, module): diff --git a/mmdet2trt/structures/__init__.py b/mmdet2trt/structures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mmdet2trt/structures/bbox/__init__.py b/mmdet2trt/structures/bbox/__init__.py new file mode 100644 index 0000000..1a047fb --- /dev/null +++ b/mmdet2trt/structures/bbox/__init__.py @@ -0,0 +1,7 @@ +from .bbox_overlaps import bbox_overlaps_batched +from .transforms import batched_bbox_cxcywh_to_xyxy, batched_distance2bbox + +__all__ = [ + 'batched_distance2bbox', 'batched_bbox_cxcywh_to_xyxy', + 'bbox_overlaps_batched' +] diff --git a/mmdet2trt/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet2trt/structures/bbox/bbox_overlaps.py similarity index 100% rename from mmdet2trt/core/bbox/iou_calculators/iou2d_calculator.py rename to mmdet2trt/structures/bbox/bbox_overlaps.py diff --git a/mmdet2trt/core/bbox/transforms.py b/mmdet2trt/structures/bbox/transforms.py similarity index 99% rename from mmdet2trt/core/bbox/transforms.py rename to mmdet2trt/structures/bbox/transforms.py index 2a1ab01..ea9976b 100644 --- a/mmdet2trt/core/bbox/transforms.py +++ b/mmdet2trt/structures/bbox/transforms.py @@ -1,6 +1,5 @@ -import torch - import mmdet2trt.ops.util_ops as mm2trt_util +import torch def batched_distance2bbox(points, distance, max_shape=None): diff --git a/setup.py b/setup.py index 9410ac2..6626a36 100644 --- a/setup.py +++ b/setup.py @@ -122,7 +122,7 @@ def run(self): setup( name='mmdet2trt', - version='0.5.0', + version='0.6.0', author='q.yao', author_email='streetyao@live.com', description='mmdetection to tensorrt converter', diff --git a/tests/model_test.py b/tests/model_test.py index 0485d34..948c284 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -3,13 +3,15 @@ import os.path as osp from argparse import ArgumentParser -import cv2 +import mmengine import torch import tqdm -from mmdet.apis import inference_detector - from mmdet2trt import mmdet2trt from mmdet2trt.apis import create_wrap_detector +from mmdet.apis import inference_detector +from mmengine.registry import init_default_scope + +import mmcv logging.basicConfig(level=logging.INFO) logger = logging.getLogger('mmdet2trt') @@ -44,25 +46,34 @@ def inference_test(trt_model, save_folder, score_thr=0.3): file_list = os.listdir(test_folder) + model_cfg = mmengine.Config.fromfile(cfg_path) + init_default_scope(model_cfg.get('default_scope', 'mmdet')) wrap_model = create_wrap_detector(trt_model, cfg_path, device) + from mmdet.registry import VISUALIZERS + visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer') + visualizer = VISUALIZERS.build(visualizer_cfg) + visualizer.dataset_meta = wrap_model.dataset_meta + for file_name in tqdm.tqdm(file_list): if not (file_name.lower().endswith('.jpg') or file_name.lower().endswith('.png')): continue image_path = osp.join(test_folder, file_name) - image = cv2.imread(image_path) + image = mmcv.imread(image_path) result = inference_detector(wrap_model, image) - wrap_model.show_result( - image, - result, - score_thr=score_thr, + visualizer.add_datasample( + 'result', + mmcv.imconvert(image, 'bgr', 'rgb'), + data_sample=result, + draw_gt=False, show=False, - out_file=osp.join(save_folder, file_name)) + out_file=osp.join(save_folder, file_name), + pred_score_thr=score_thr) TEST_MODE_DICT = {'convert': 1, 'inference': 1 << 1, 'all': 0b11} diff --git a/tools/collect_env.py b/tools/collect_env.py index 852a5d1..2991011 100644 --- a/tools/collect_env.py +++ b/tools/collect_env.py @@ -2,7 +2,6 @@ # https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py from __future__ import print_function - # Unlike the rest of the PyTorch this file must be python2 compliant. # This script outputs relevant system environment info # Run it with `python collect_env.py`. diff --git a/tools/test.py b/tools/test.py index 90c5494..b445dea 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,11 +1,11 @@ import argparse -import mmcv -from mmcv.parallel import MMDataParallel +from mmdet2trt.apis import create_wrap_detector from mmdet.apis import single_gpu_test from mmdet.datasets import build_dataloader, build_dataset -from mmdet2trt.apis import create_wrap_detector +import mmcv +from mmcv.parallel import MMDataParallel def parse_args():