Commit 85a9267: update readme

grimoire committed Feb 18, 2024
1 parent b98cc0d commit 85a9267

Showing 10 changed files with 90 additions and 90 deletions.
75 changes: 35 additions & 40 deletions README.md
@@ -1,13 +1,21 @@
# MMDet to TensorRT

> [!NOTE]
>
> The main branch supports model conversion for MMDetection>=3.0.
> If you want to convert a model on an older MMDetection, please switch to the corresponding branches:
> - [mmdet2trt=v0.5.0](https://github.com/grimoire/mmdetection-to-tensorrt/tree/v0.5.0)
> - [torch2trt_dynamic=v0.5.0](https://github.com/grimoire/torch2trt_dynamic/tree/v0.5.0)
> - [amirstan_plugin=v0.5.0](https://github.com/grimoire/amirstan_plugin/tree/v0.5.0).
## News

- 2024.02: Support MMDetection>=3.0

## Introduction

This project aims to support End2End deployment of models from MMDetection with TensorRT.

Mask support is **experimental**.

Features:
@@ -17,9 +25,7 @@ Features:
- batched input
- dynamic input shape
- combination of different modules
- DeepStream

## Requirement

@@ -52,11 +58,13 @@ Any advices, bug reports and stars are welcome.
```bash
make -j10
```

> [!NOTE]
>
> **DON'T FORGET** to set the environment variable (in `~/.bashrc`):
>
> ```bash
> export AMIRSTAN_LIBRARY_PATH=${amirstan_plugin_root}/build/lib
> ```

## Installation
@@ -73,17 +81,9 @@ pip install -e .
Build the docker image:
```bash
# cuda11.1 TensorRT7.2.2 pytorch1.8
sudo docker build -t mmdet2trt_docker:v1.0 docker/
```
Run it (this will show the help for the CLI entrypoint):
```bash
sudo docker run --gpus all -it --rm -v ${your_data_path}:${bind_path} mmdet2trt_docker:v1.0
```

## Usage

Create a TensorRT model from an mmdet model (conversion might take a few minutes and may print some warnings).
Details can be found in [getting_started.md](./docs/getting_started.md).

### CLI
```bash
# conversion might take a few minutes.
mmdet2trt ${CONFIG_PATH} ${CHECKPOINT_PATH} ${OUTPUT_PATH}
```
@@ -118,15 +119,17 @@ Run mmdet2trt -h for help on optional arguments.
### Python
```python
shape_ranges = dict(
    x=dict(
        min=[1, 3, 320, 320],
        opt=[1, 3, 800, 1344],
        max=[1, 3, 1344, 1344],
    ))

trt_model = mmdet2trt(cfg_path,
                      weight_path,
                      shape_ranges=shape_ranges,
                      fp16_mode=True)

# save converted model
torch.save(trt_model.state_dict(), save_model_path)

with open(save_engine_path, mode='wb') as f:
f.write(trt_model.state_dict()['engine'])
```
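
To reuse the converted model later without re-running the conversion, load the saved state dict back. A minimal sketch, assuming `torch2trt_dynamic` exposes the usual `TRTModule` wrapper:

```python
import torch
from torch2trt_dynamic import TRTModule  # assumed import path

# restore the converted model from the saved state dict
trt_model = TRTModule()
trt_model.load_state_dict(torch.load(save_model_path))
```
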
> [!NOTE]
>
> The input of the engine is the tensor **after preprocessing**.
> The output of the engine is `num_dets, bboxes, scores, class_ids`. If you enable the `enable_mask` flag, there will be another output `mask`.
> The bboxes output of the engine is not divided by the `scale_factor`.
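
Since the engine consumes the preprocessed (resized) image, mapping boxes back to the original image is up to you. A minimal sketch, assuming `scale_factor` is the `[w_scale, h_scale, w_scale, h_scale]` factor applied during preprocessing:

```python
# num_dets, bboxes, scores, class_ids come from the engine
# scale_factor = [w_scale, h_scale, w_scale, h_scale] used when resizing
bboxes = bboxes / bboxes.new_tensor(scale_factor)
```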

How to perform inference with the converted model:
```python
from mmdet.apis import inference_detector
from mmdet2trt.apis import create_wrap_detector

trt_detector = create_wrap_detector(trt_model, cfg_path, device_id)

# result shares the same format as mmdetection
result = inference_detector(trt_detector, image_path)
```
Try the demo in `demo/inference.py`, or `demo/cpp` if you want to run inference with the C++ API.
9 changes: 5 additions & 4 deletions demo/cpp/README.md
@@ -3,25 +3,26 @@
Best used within the built docker container (use the Dockerfile provided in the project to build the image).

## Requirements

The sample additionally requires OpenCV:

```bash
apt-get install -y libopencv-dev
```
```

## Install

Within `<mmdetection-to-trt-root>/demo/cpp`:

```bash
mkdir build && cd build
cmake -Damirstan_plugin_root=<path-to-amirstan_plugin-root> ..
make -j4
```

## Run the sample

```bash
build/trt_sample <serialized model filepath (.engine)> <test image(-s) paths>
```

@@ -31,7 +32,7 @@ The sample is implemented for the TensorRT model converted from mmdetection DCNv

To obtain the converted model and the serialized TensorRT engine, run the following command within the provided docker container (`~/space` folder):

```bash
mmdet2trt --save-engine=true \
--min-scale 1 3 320 320 \
--opt-scale 1 3 544 960 \
```

22 changes: 16 additions & 6 deletions demo/inference.py
@@ -4,6 +4,9 @@
from mmdet2trt import mmdet2trt
from mmdet2trt.apis import create_wrap_detector
from mmdet.apis import inference_detector
from mmdet.registry import VISUALIZERS

import mmcv


def main():
@@ -32,12 +35,19 @@

result = inference_detector(trt_detector, image_path)

# visualize
visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer')
visualizer = VISUALIZERS.build(visualizer_cfg)
visualizer.dataset_meta = trt_detector.dataset_meta

image = mmcv.imread(image_path)
visualizer.add_datasample(
'result',
mmcv.imconvert(image, 'bgr', 'rgb'),
data_sample=result,
draw_gt=False,
show=True,
pred_score_thr=args.score_thr)


if __name__ == '__main__':
20 changes: 6 additions & 14 deletions docker/Dockerfile
@@ -1,8 +1,8 @@
FROM nvcr.io/nvidia/tensorrt:24.01-py3

ARG CUDA=12.1
ARG TORCH_VERSION=2.2.0
ARG TORCHVISION_VERSION=0.17.0

ENV DEBIAN_FRONTEND=noninteractive

@@ -22,16 +22,8 @@ RUN pip3 install torch==${TORCH_VERSION}+cu${CUDA//./} torchvision==${TORCHVISIO


### install mmdetection
RUN pip3 install openmim &&\
mim install mmdet==3.3.0

### git amirstan plugin
RUN git clone --depth=1 https://github.com/grimoire/amirstan_plugin.git /root/space/amirstan_plugin &&\
40 changes: 21 additions & 19 deletions docs/getting_started.md
@@ -13,26 +13,26 @@ This page provides details about mmdet2trt.

## dynamic shape/batched input

`shape_ranges` is used to set the min/optimize/max shape of the input tensor. For each dimension, `min <= opt <= max`. For example:

```python
shape_ranges = dict(
    x=dict(
        min=[1, 3, 320, 320],
        opt=[1, 3, 800, 1344],
        max=[1, 3, 1344, 1344],
    ))

trt_model = mmdet2trt(...,
                      shape_ranges=shape_ranges,  # set the shape ranges
...)
```

This config will give you input tensors with spatial size from (320, 320) up to (1344, 1344). For batched input, increase the batch dimension of `opt` and `max` (e.g. `max=[4, 3, 1344, 1344]` for max batch_size=4).

> [!WARNING]
>
> Dynamic input shape and batch support might need more memory. Use a fixed shape (min=opt=max) to avoid unnecessary memory usage.
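
For example, a fixed-shape config looks like this (a sketch; the shape values are illustrative):

```python
# min = opt = max: the engine is built for exactly one input shape
shape_ranges = dict(
    x=dict(
        min=[1, 3, 800, 1344],
        opt=[1, 3, 800, 1344],
        max=[1, 3, 800, 1344],
    ))
```
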
## fp16 support
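
Fp16 conversion is enabled with the `fp16_mode` flag, as in the README example above; a minimal sketch:

```python
# build the engine with fp16 kernels enabled where supported
trt_model = mmdet2trt(cfg_path, model_path, fp16_mode=True)
```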

## int8 support

```python
trt_model = mmdet2trt(cfg_path, model_path,
                      int8_calib_alg="entropy")
```
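
A fuller conversion sketch: `int8_mode` is an assumption by analogy with `fp16_mode`, and the calibration-data arguments are omitted:

```python
# int8 sketch; `int8_mode` is assumed, calibration data not shown
trt_model = mmdet2trt(cfg_path, model_path,
                      int8_mode=True,
                      int8_calib_alg="entropy")
```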

> [!WARNING]
>
> Not all models support int8 mode.

## max workspace size
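
Some modules and tactics need a large workspace. A sketch using the `max_workspace_size` argument from the pre-3.0 README example (assumed to still apply):

```python
# 1 GiB workspace; some modules and tactics need a large workspace
trt_model = mmdet2trt(cfg_path, model_path,
                      max_workspace_size=1 << 30)
```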

@@ -92,9 +92,9 @@ with open(engine_path, mode='wb') as f:

Link `${AMIRSTAN_PLUGIN_DIR}/build/lib/libamirstan_plugin.so` in your project (or load it at runtime). Compile and load the engine.

> [!WARNING]
>
> You might need to invoke `initLibAmirstanInferPlugins()` from [amirInferPlugin.h](https://github.com/grimoire/amirstan_plugin/blob/master/include/plugin/amirInferPlugin.h) to load the plugins.
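
In Python, the plugin library can likewise be loaded at runtime before deserializing the engine. A sketch, assuming `AMIRSTAN_LIBRARY_PATH` is set as described in the README:

```python
import ctypes
import os

# loading the shared library registers the custom plugins with TensorRT
plugin_path = os.path.join(os.environ['AMIRSTAN_LIBRARY_PATH'],
                           'libamirstan_plugin.so')
ctypes.CDLL(plugin_path)
```
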
The engine only contains the inference forward pass. Preprocessing (resize, normalize) and postprocessing (dividing by the scale factor) should be done in your project.

@@ -162,4 +162,6 @@ set flag `enable_mask` to True
```python
trt_model = mmdet2trt(..., enable_mask=True)
```

> [!NOTE]
>
> The mask output has shape `[batch_size, num_boxes, 28, 28]`; the post-processing of masks is not included in the model. Please implement it yourself if you want to integrate the converted engine into your own project.
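
A minimal sketch of such post-processing (illustrative, not part of this repo): resize each fixed-size mask to its box and paste it into a full-image mask.

```python
import cv2
import numpy as np


def paste_mask(mask28, bbox, img_h, img_w, thr=0.5):
    """Paste a [28, 28] float mask from the engine into a full-image mask."""
    x1 = min(max(int(bbox[0]), 0), img_w - 1)
    y1 = min(max(int(bbox[1]), 0), img_h - 1)
    w = max(min(int(bbox[2]), img_w) - x1, 1)
    h = max(min(int(bbox[3]), img_h) - y1, 1)
    mask = cv2.resize(mask28, (w, h)) > thr  # box-sized binary mask
    full = np.zeros((img_h, img_w), dtype=bool)
    full[y1:y1 + h, x1:x1 + w] = mask
    return full
```
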
4 changes: 2 additions & 2 deletions mmdet2trt/models/dense_heads/anchor_head.py
@@ -16,7 +16,7 @@ class AnchorHeadWraper(nn.Module):
def __init__(self, module):
super(AnchorHeadWraper, self).__init__()
self.module = module
self.prior_generator = build_wrapper(self.module.prior_generator)
self.bbox_coder = build_wrapper(self.module.bbox_coder)

self.test_cfg = module.test_cfg
@@ -32,7 +32,7 @@ def forward(self, feat, x):

cls_scores, bbox_preds = module(feat)

mlvl_anchors = self.prior_generator(
cls_scores, device=cls_scores[0].device)

mlvl_scores = []
2 changes: 1 addition & 1 deletion mmdet2trt/models/dense_heads/atss_head.py
@@ -16,7 +16,7 @@ def forward(self, feat, x):

cls_scores, bbox_preds, centernesses = module(feat)

mlvl_anchors = self.prior_generator(
cls_scores, device=cls_scores[0].device)

mlvl_scores = []
2 changes: 1 addition & 1 deletion mmdet2trt/models/dense_heads/cascade_rpn_head.py
@@ -25,7 +25,7 @@ def forward(self, x, offset_list):
return self.module(x, offset_list)

def get_anchors(self, featmaps, device='cuda'):
return self.prior_generator(featmaps, device=device)

def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):

4 changes: 2 additions & 2 deletions mmdet2trt/models/dense_heads/gfl_head.py
@@ -48,7 +48,7 @@ def forward(self, feat, x):
cls_scores, bbox_preds = module(feat)

num_levels = len(cls_scores)
mlvl_anchors = self.prior_generator(
cls_scores, device=cls_scores[0].device)

mlvl_scores = []
@@ -58,7 +58,7 @@ def forward(self, feat, x):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
anchors = mlvl_anchors[idx]
stride = module.prior_generator.strides[idx]
scores = rpn_cls_score.permute(0, 2, 3, 1).reshape(
rpn_cls_score.shape[0], -1, module.cls_out_channels).sigmoid()
bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1)
2 changes: 1 addition & 1 deletion mmdet2trt/models/dense_heads/paa_head.py
@@ -17,7 +17,7 @@ def forward(self, feat, x):

cls_scores, bbox_preds, iou_preds = module(feat)

mlvl_anchors = self.prior_generator(
cls_scores, device=cls_scores[0].device)

mlvl_scores = []
