CenterNet is a strong single-stage, single-scale, and anchor-free object detector. This implementation is built with PyTorch Lightning, supports TorchScript and ONNX export, and has a modular design that makes components easy to customize.
References
To read more about the architecture and code structure of this implementation, see implementation.md.
Dependencies
```bash
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
pip install pytorch-lightning pycocotools albumentations
pip install git+https://github.com/gau-nernst/vision-toolbox.git  # backbones and necks
pip install filterpy git+https://github.com/JonathonLuiten/TrackEval.git  # for FairMOT
pip install "jsonargparse[signatures]"  # for training; quotes stop shells like zsh from expanding the brackets
```
Import `build_centernet` from `centernet_lightning.models` to build a CenterNet model from a YAML config file. Sample config files are provided in the `configs/` directory.
```python
from centernet_lightning.models import build_centernet

model = build_centernet("configs/coco_resnet34.yaml")
```
You can also load a CenterNet model directly from a checkpoint, thanks to PyTorch Lightning.
```python
from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")
```
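For inference, use `CenterNet.inference_detection()` for detection or `CenterNet.inference_tracking()` for tracking. For example, to run detection on a directory of images: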
```python
model = ...  # create a model as above
img_dir = "path/to/img/dir"
detections = model.inference_detection(img_dir, num_detections=100)
```
`detections` is a dictionary with the following keys:

| Key | Description | Shape |
|---|---|---|
| `bboxes` | Bounding boxes in x1y1x2y2 format | (num_images x num_detections x 4) |
| `labels` | Class labels | (num_images x num_detections) |
| `scores` | Confidence scores | (num_images x num_detections) |

All values are `np.ndarray`, ready for post-processing.
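As one example of such post-processing, the sketch below keeps detections above a score threshold and converts them to COCO-style result dicts (COCO results use `[x, y, width, height]` boxes), e.g. for evaluation with pycocotools. It relies only on the array layout in the table above. The function name `to_coco_results`, the 0.3 default threshold, and the identity label-to-category mapping are hypothetical choices; adapt them to your dataset.

```python
import numpy as np


def to_coco_results(detections, image_ids, score_threshold=0.3, label_to_category=None):
    """Convert batched detections to a list of COCO-style result dicts.

    detections: dict with "bboxes" (N, K, 4) in x1y1x2y2, "labels" (N, K), "scores" (N, K)
    image_ids: sequence of N image ids, in the same order as the batch
    label_to_category: optional dict from model label to dataset category id
                       (hypothetical; the raw label is used if None)
    """
    results = []
    for i, image_id in enumerate(image_ids):
        bboxes = np.asarray(detections["bboxes"][i])
        labels = np.asarray(detections["labels"][i])
        scores = np.asarray(detections["scores"][i])

        keep = scores >= score_threshold  # drop low-confidence detections
        for (x1, y1, x2, y2), label, score in zip(bboxes[keep], labels[keep], scores[keep]):
            label = int(label)
            category = label if label_to_category is None else label_to_category[label]
            results.append({
                "image_id": image_id,
                "category_id": category,
                # COCO expects [x, y, width, height] instead of x1y1x2y2
                "bbox": [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                "score": float(score),
            })
    return results
```

Note that COCO category ids are not contiguous, so for COCO itself you would pass an explicit `label_to_category` mapping rather than relying on the raw labels.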
You can also run pre-processing, the forward pass, and decoding manually. This is useful when you use CenterNet in your own applications:
```python
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# read image (OpenCV loads BGR, so convert to RGB)
img = cv2.imread("path/to/image")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# apply pre-processing: resize to 512x512 and normalize with ImageNet statistics
# torchvision.transforms should also work
transforms = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(),
    ToTensorV2(),
])
img = transforms(image=img)["image"]

# create a model as above and put it in evaluation mode
model = ...
model.eval()

# turn off gradient calculation, then run the forward pass and decode the outputs
with torch.no_grad():
    encoded_outputs = model(img.unsqueeze(0))  # add a batch dimension
    detections = model.gather_detection2d(encoded_outputs)
```
`detections` has the same format as above, but the values are `torch.Tensor`.
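Because the pre-processing above resizes every image to 512x512, the returned boxes are presumably expressed in that resized frame (an assumption; verify it against your own outputs). The NumPy sketch below maps x1y1x2y2 boxes back to the original image size. `rescale_boxes` is a hypothetical helper, not part of this package. To use it, record `orig_h, orig_w = img.shape[:2]` right after `cv2.imread` (before `img` is overwritten by the transforms) and convert tensors with `.cpu().numpy()` first.

```python
import numpy as np


def rescale_boxes(bboxes, input_size, original_size):
    """Map x1y1x2y2 boxes from the resized model input back to the original image.

    bboxes: array of shape (..., 4) in x1y1x2y2 format, in input-image pixels (assumed)
    input_size: (height, width) fed to the model, e.g. (512, 512)
    original_size: (height, width) of the image before resizing
    """
    in_h, in_w = input_size
    orig_h, orig_w = original_size
    # x coordinates scale with width, y coordinates scale with height
    scale = np.array(
        [orig_w / in_w, orig_h / in_h, orig_w / in_w, orig_h / in_h], dtype=np.float32
    )
    return np.asarray(bboxes, dtype=np.float32) * scale
```

Since the scale vector broadcasts over the last axis, the same call works for a `(num_detections, 4)` array or a batched `(num_images, num_detections, 4)` array, as long as every image in the batch had the same original size.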
Note: Because of the data augmentations used during training, the model is robust enough that ImageNet normalization is not needed at inference time. You can instead scale the input image to [0, 1], and CenterNet should still work fine.
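If you skip ImageNet normalization, the input only needs to be scaled to [0, 1]. With albumentations, `A.Normalize(mean=(0, 0, 0), std=(1, 1, 1))` should do this, since its default `max_pixel_value` is 255. The NumPy sketch below shows the same step explicitly. `to_unit_range_chw` is a hypothetical helper and assumes an RGB `uint8` image that has already been resized.

```python
import numpy as np


def to_unit_range_chw(img):
    """Convert an RGB uint8 HWC image to a float32 CHW array in [0, 1].

    Intended as a stand-in for A.Normalize() + ToTensorV2() when skipping
    ImageNet normalization; resizing (e.g. to 512x512) must happen beforehand.
    """
    img = np.asarray(img)
    if img.dtype != np.uint8:
        raise TypeError(f"expected a uint8 image, got {img.dtype}")
    chw = img.transpose(2, 0, 1)  # HWC -> CHW, the layout PyTorch models expect
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0
```

The result can be turned into a batched tensor with `torch.from_numpy(x).unsqueeze(0)` before calling the model.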
`CenterNet` is export-friendly. You can export a trained model directly to ONNX or TorchScript (tracing only) using the PyTorch Lightning API:
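```python
import torch
from centernet_lightning.models import CenterNet

model = CenterNet.load_from_checkpoint("path/to/checkpoint.ckpt")
dummy_input = torch.rand((1, 3, 512, 512))  # example input: (batch, channels, height, width)

model.to_onnx("model.onnx", dummy_input)  # export to ONNX
# export to TorchScript; tracing needs an example input. Scripting might not work
model.to_torchscript("model.pt", method="trace", example_inputs=dummy_input)
```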
WIP
You can train CenterNet with the provided training script `train.py` and a config file:
```bash
python train.py --config "configs/coco_resnet34.yaml"
```
Sample config files are in `configs/`. To customize training, see training.md.
The following dataset formats are supported:
Detection:
Tracking:
To learn how to use each dataset type, see datasets.md.