通过本文档,你将会知道如何使用自定义数据集对预先定义好的模型进行推理,测试以及训练。我们使用 balloon dataset 作为例子来描述整个过程。
基本步骤如下:
- 准备自定义数据集
- 准备配置文件
- 在自定义数据集上进行训练,测试和推理。
MMDetection 一共支持三种形式应用新数据集:
- 将数据集重新组织为 COCO 格式。
- 将数据集重新组织为一个中间格式。
- 实现一个新的数据集。
我们通常建议使用前面两种方法,因为它们通常来说比第三种方法要简单。
在本文档中,我们展示一个例子来说明如何将数据转化为 COCO 格式。
注意:MMDetection 现只支持对 COCO 格式的数据集进行 mask AP 的评测。
所以用户如果要进行实例分割,只能将数据转成 COCO 格式。
用于实例分割的 COCO 数据集格式如下所示,其中的键(key)都是必要的,参考这里来获取更多细节。
{
"images": [image],
"annotations": [annotation],
"categories": [category]
}
image = {
"id": int,
"width": int,
"height": int,
"file_name": str,
}
annotation = {
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories = [{
"id": int,
"name": str,
"supercategory": str,
}]
现在假设我们使用 balloon dataset。
下载了数据集之后,我们需要实现一个函数将标注格式转化为 COCO 格式。然后我们就可以使用已经实现的 COCODataset
类来加载数据并进行训练以及评测。
如果你浏览过新数据集,你会发现格式如下:
{'base64_img_data': '',
'file_attributes': {},
'filename': '34020010494_e5cb88e1c4_k.jpg',
'fileref': '',
'regions': {'0': {'region_attributes': {},
'shape_attributes': {'all_points_x': [1020,
1000,
994,
1003,
1023,
1050,
1089,
1134,
1190,
1265,
1321,
1361,
1403,
1428,
1442,
1445,
1441,
1427,
1400,
1361,
1316,
1269,
1228,
1198,
1207,
1210,
1190,
1177,
1172,
1174,
1170,
1153,
1127,
1104,
1061,
1032,
1020],
'all_points_y': [963,
899,
841,
787,
738,
700,
663,
638,
621,
619,
643,
672,
720,
765,
800,
860,
896,
942,
990,
1035,
1079,
1112,
1129,
1134,
1144,
1153,
1166,
1166,
1150,
1136,
1129,
1122,
1112,
1084,
1037,
989,
963],
'name': 'polygon'}}},
'size': 1115004}
标注文件时是 JSON 格式的,其中所有键(key)组成了一张图片的所有标注。
其中将 balloon dataset 转化为 COCO 格式的代码如下所示。
import os.path as osp
import mmcv
def convert_balloon_to_coco(ann_file, out_file, image_prefix):
data_infos = mmcv.load(ann_file)
annotations = []
images = []
obj_count = 0
for idx, v in enumerate(mmcv.track_iter_progress(data_infos.values())):
filename = v['filename']
img_path = osp.join(image_prefix, filename)
height, width = mmcv.imread(img_path).shape[:2]
images.append(dict(
id=idx,
file_name=filename,
height=height,
width=width))
bboxes = []
labels = []
masks = []
for _, obj in v['regions'].items():
assert not obj['region_attributes']
obj = obj['shape_attributes']
px = obj['all_points_x']
py = obj['all_points_y']
poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
poly = [p for x in poly for p in x]
x_min, y_min, x_max, y_max = (
min(px), min(py), max(px), max(py))
data_anno = dict(
image_id=idx,
id=obj_count,
category_id=0,
bbox=[x_min, y_min, x_max - x_min, y_max - y_min],
area=(x_max - x_min) * (y_max - y_min),
segmentation=[poly],
iscrowd=0)
annotations.append(data_anno)
obj_count += 1
coco_format_json = dict(
images=images,
annotations=annotations,
categories=[{'id':0, 'name': 'balloon'}])
mmcv.dump(coco_format_json, out_file)
使用如上的函数,用户可以成功将标注文件转化为 JSON 格式,之后可以使用 CocoDataset
对模型进行训练和评测。
第二步需要准备一个配置文件来成功加载数据集。假设我们想要用 balloon dataset 来训练配备了 FPN 的 Mask R-CNN ,如下是我们的配置文件。假设配置文件命名为 mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
,相应保存路径为 configs/balloon/
,配置文件内容如下所示。
# 这个新的配置文件继承自一个原始配置文件,只需要突出必要的修改部分即可
_base_ = 'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py'
# 我们需要对头中的类别数量进行修改来匹配数据集的标注
model = dict(
roi_head=dict(
bbox_head=dict(num_classes=1),
mask_head=dict(num_classes=1)))
# 修改数据集相关设置
dataset_type = 'CocoDataset'
classes = ('balloon',)
data = dict(
train=dict(
img_prefix='balloon/train/',
classes=classes,
ann_file='balloon/train/annotation_coco.json'),
val=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'),
test=dict(
img_prefix='balloon/val/',
classes=classes,
ann_file='balloon/val/annotation_coco.json'))
# 我们可以使用预训练的 Mask R-CNN 来获取更好的性能
load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
为了使用新的配置方法来对模型进行训练,你只需要运行如下命令。
python tools/train.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py
参考情况 1来获取更多详细的使用方法。
为了测试训练完毕的模型,你只需要运行如下命令。
python tools/test.py configs/balloon/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py work_dirs/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_balloon.py/latest.pth --eval bbox segm
参考情况 1来获取更多详细的使用方法。