Implementation of CenterNet and FairMOT with PyTorch Lightning

gau-nernst/centernet-lightning

CenterNet

CenterNet is a strong single-stage, single-scale, and anchor-free object detector. This implementation is built with PyTorch Lightning, supports TorchScript and ONNX export, and has a modular design that makes customizing components simple.
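"Anchor-free" means that CenterNet represents each object as a peak in a per-class heatmap at the box center, plus a width/height and a sub-pixel offset regressed at that location. The pure-Python sketch below illustrates decoding as described in the CenterNet paper: a 3x3 max-pool keeps only local peaks (standing in for NMS), the top-k peaks are selected, and boxes are rebuilt from the size and offset maps. It is not this repository's code. The nested-list layout, the name decode_heatmap, and the output stride of 4 are assumptions for illustration.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float, int, float]  # x1, y1, x2, y2, label, score

def decode_heatmap(
    heatmap: List[List[List[float]]],  # [num_classes][H][W], scores in [0, 1]
    wh: List[List[List[float]]],       # [2][H][W], box width/height in output-grid units
    offset: List[List[List[float]]],   # [2][H][W], sub-pixel center offsets (x, y)
    k: int = 100,
    stride: int = 4,                   # assumed input-to-output downsampling factor
) -> List[Box]:
    """Return up to k boxes (x1, y1, x2, y2, label, score) in input-image pixels."""
    num_classes = len(heatmap)
    height, width = len(heatmap[0]), len(heatmap[0][0])
    peaks = []
    for c in range(num_classes):
        for y in range(height):
            for x in range(width):
                score = heatmap[c][y][x]
                # 3x3 max-pool test: keep the point only if no neighbour is larger
                neighbours = [
                    heatmap[c][ny][nx]
                    for ny in range(max(0, y - 1), min(height, y + 2))
                    for nx in range(max(0, x - 1), min(width, x + 2))
                ]
                if score >= max(neighbours):
                    peaks.append((score, c, y, x))
    peaks.sort(reverse=True)  # highest score first
    boxes: List[Box] = []
    for score, c, y, x in peaks[:k]:
        # refine the integer peak location with the predicted offset
        cx = x + offset[0][y][x]
        cy = y + offset[1][y][x]
        w, h = wh[0][y][x], wh[1][y][x]
        # build the box on the output grid, then scale back to input pixels
        boxes.append((
            (cx - w / 2) * stride, (cy - h / 2) * stride,
            (cx + w / 2) * stride, (cy + h / 2) * stride,
            c, score,
        ))
    return boxes
```

Because the number of returned boxes is fixed at k, a real decoder typically returns many low-score candidates that are filtered afterwards.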

References

To read more about the architecture and code structure of this implementation, see implementation.md.

Install

Dependencies

conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
pip install pytorch-lightning pycocotools albumentations
pip install git+https://github.com/gau-nernst/vision-toolbox.git            # backbones and necks
pip install filterpy git+https://github.com/JonathonLuiten/TrackEval.git    # for FairMOT
pip install jsonargparse[signatures]                                        # for training
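To confirm that the packages above are importable in the active environment, a small standard-library check can help. The import names below are assumptions mapped from the package names (for example trackeval for TrackEval). vision-toolbox is left out because its import name is not given in this README.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed import names for the packages installed above (vision-toolbox omitted)
REQUIRED_MODULES = [
    "torch", "torchvision", "pytorch_lightning", "pycocotools",
    "albumentations", "filterpy", "trackeval", "jsonargparse",
]

def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found on the Python path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

if __name__ == "__main__":
    for name, ok in check_modules(REQUIRED_MODULES).items():
        print(f"{name:20s} {'OK' if ok else 'MISSING'}")
```

find_spec only locates modules without importing them, so the check is fast and has no side effects.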

Inference

Create a CenterNet model

Import build_centernet from models to build a CenterNet model from a YAML file. Sample config files are provided in the configs/ directory.

from centernet_lightning.models import build_centernet

model = build_centernet("configs/coco_resnet34.yaml")

You can also load a CenterNet model directly from a checkpoint, thanks to PyTorch Lightning.

from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")

Folder of images

Use CenterNet.inference_detection() for detection or CenterNet.inference_tracking() for tracking.

model = ...     # create a model as above
img_dir = "path/to/img/dir"
detections = model.inference_detection(img_dir, num_detections=100)

detections is a dictionary with the following keys:

Key     Description                        Shape
bboxes  bounding boxes in x1y1x2y2 format  (num_images x num_detections x 4)
labels  class labels                       (num_images x num_detections)
scores  confidence scores                  (num_images x num_detections)

Results are np.ndarray, ready for post-processing.
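Because every image always gets num_detections entries, low-confidence detections usually need to be discarded. A minimal post-processing sketch, assuming an arbitrary score threshold of 0.3 and COCO-style output with x, y, width, height boxes. The helper name filter_detections and the output dict layout are illustrative, not part of this repository.

```python
from typing import Dict, List, Sequence

def filter_detections(
    bboxes: Sequence[Sequence[float]],  # num_detections x 4, x1y1x2y2
    labels: Sequence[int],              # num_detections
    scores: Sequence[float],            # num_detections
    score_threshold: float = 0.3,       # assumed value; tune per application
) -> List[Dict]:
    """Keep detections above a score threshold and convert boxes to x, y, w, h."""
    results = []
    for (x1, y1, x2, y2), label, score in zip(bboxes, labels, scores):
        if score < score_threshold:
            continue
        results.append({
            "category_id": int(label),
            # cast to plain floats so the result is JSON-serializable
            "bbox": [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
            "score": float(score),
        })
    return results
```

For image i, call filter_detections(detections["bboxes"][i], detections["labels"][i], detections["scores"][i]). The label is kept as the raw class index; if the model uses contiguous indices, mapping them back to the original (non-contiguous) COCO category ids may be a separate step.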

Single image

Running inference on a single image is useful when you embed CenterNet in your own application.

import numpy as np
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# read image
img = cv2.imread("path/to/image")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# apply pre-processing: resize to 512x512 and normalize with ImageNet statistics
# torchvision.transforms should also work
transforms = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(),
    ToTensorV2()
])
img = transforms(image=img)["image"]

# create a model as above and put it in evaluation mode
model = ...     
model.eval()

# turn off gradient calculation and do forward pass
with torch.no_grad():
    encoded_outputs = model(img.unsqueeze(0))
    detections = model.gather_detection2d(encoded_outputs)

detections has the same keys and shapes as above (with num_images = 1), but the values are torch.Tensor instead of np.ndarray.
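A.Resize stretches the image to 512x512 without preserving the aspect ratio, so boxes must be rescaled separately along x and y to land on the original image. The sketch below assumes that gather_detection2d returns boxes in pixel coordinates of the resized 512x512 input; if they are in output-grid units or normalized coordinates, the scale factors would need adjusting. The helper name rescale_boxes is hypothetical.

```python
from typing import List, Sequence, Tuple

def rescale_boxes(
    bboxes: Sequence[Sequence[float]],  # x1y1x2y2 in resized-input pixels (assumption)
    input_size: Tuple[int, int],        # (height, width) fed to the model, e.g. (512, 512)
    original_size: Tuple[int, int],     # (height, width) of the image before A.Resize
) -> List[List[float]]:
    """Map boxes from the resized model input back to the original image resolution."""
    in_h, in_w = input_size
    orig_h, orig_w = original_size
    sx, sy = orig_w / in_w, orig_h / in_h  # independent scales: aspect ratio was not kept
    return [
        [float(x1) * sx, float(y1) * sy, float(x2) * sx, float(y2) * sy]
        for x1, y1, x2, y2 in bboxes
    ]
```

Record original_size = img.shape[:2] right after cv2.cvtColor, before img is overwritten by the transforms, then call rescale_boxes(detections["bboxes"][0].tolist(), (512, 512), original_size).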

Note: Thanks to the data augmentations used during training, the model is robust enough that ImageNet normalization is not required at inference. You can simply scale the input image to [0, 1] and CenterNet should still work well.
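For reference, the two input scalings differ as follows. With default arguments, albumentations' A.Normalize() computes (pixel / 255 - mean) / std per channel using the ImageNet statistics below, while plain [0, 1] scaling only divides by 255. The function names here are illustrative.

```python
from typing import Sequence, Tuple

# ImageNet statistics (RGB order), also the defaults of albumentations' A.Normalize
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def imagenet_normalize(rgb: Sequence[int]) -> Tuple[float, ...]:
    """Per-channel (pixel / 255 - mean) / std, as A.Normalize() does by default."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))

def unit_scale(rgb: Sequence[int]) -> Tuple[float, ...]:
    """Plain [0, 1] scaling of 8-bit pixel values."""
    return tuple(v / 255.0 for v in rgb)
```

With albumentations, [0, 1] scaling can be obtained by using A.Normalize(mean=(0, 0, 0), std=(1, 1, 1)) in place of A.Normalize().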

Deployment

CenterNet is export-friendly. You can directly export a trained model to ONNX or TorchScript (tracing only) using the PyTorch Lightning API.

import torch
from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")
sample = torch.rand((1, 3, 512, 512))                                    # dummy input used for export
model.to_onnx("model.onnx", sample)                                      # export to ONNX
model.to_torchscript("model.pt", method="trace", example_inputs=sample)  # tracing needs example inputs; scripting might not work

Evaluate a trained model

WIP

Training CenterNet

You can train CenterNet with the provided training script train.py and a config file.

python train.py --config "configs/coco_resnet34.yaml"

See sample config files in configs/. To customize training, see training.md.

Datasets

The following dataset formats are supported:

Detection:

Tracking:

To see how to use each dataset type, see datasets.md.