# eval_cam.py — 59 lines (49 loc) · 2.35 KB
# (GitHub file-viewer chrome and gutter line numbers removed from this scrape.)
import numpy as np
import os
import argparse
from PIL import Image
from tool import pyutils
from data import data_voc, data_coco
from tqdm import tqdm
from torch.utils.data import DataLoader
from chainercv.evaluations import calc_semantic_segmentation_confusion
def run(args, predict_dir, num_cls):
    """Evaluate saved IS-CAM score maps against ground-truth segmentation masks.

    Args:
        args: parsed CLI namespace; only ``args.gt_path`` (directory of
            ground-truth PNG label maps, one per image) is read here.
        predict_dir: directory holding one ``<name>.npy`` file per image,
            each a pickled dict with 'IS_CAM' (per-class score maps) and
            'keys' (the class ids those maps correspond to).
        num_cls: number of classes including background (21 VOC / 81 COCO).

    Prints a dict with the per-class IoU vector and the mean IoU.
    """
    preds = []
    masks = []
    # NOTE(review): reads the module-level `dataloader` created in the
    # __main__ block below — call this only after that block has run.
    for _, pack in tqdm(enumerate(dataloader)):
        name = pack['name'][0]
        cam_dict = np.load(os.path.join(predict_dir, name + '.npy'), allow_pickle=True).item()
        cams = cam_dict['IS_CAM']
        keys = cam_dict['keys']
        # Winning score map per pixel, then map back to the actual class ids.
        cls_labels = keys[np.argmax(cams, axis=0)]
        preds.append(cls_labels.copy())
        mask = np.array(Image.open(os.path.join(args.gt_path, name + '.png')))
        masks.append(mask.copy())

    confusion = calc_semantic_segmentation_confusion(preds, masks)[:num_cls, :num_cls]
    gtj = confusion.sum(axis=1)   # ground-truth pixels per class
    resj = confusion.sum(axis=0)  # predicted pixels per class
    gtjresj = np.diag(confusion)  # correctly predicted pixels per class
    denominator = gtj + resj - gtjresj
    # A class absent from both GT and prediction gives 0/0 -> NaN, which
    # nanmean already skips; suppress the harmless divide warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = gtjresj / denominator
    print({'iou': iou, 'miou': np.nanmean(iou)})
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # `choices` rejects invalid values even under `python -O`, where a bare
    # `assert` would be silently stripped.
    parser.add_argument("--dataset", default="voc", type=str, choices=['voc', 'coco'],
                        help="Dataset must be voc or coco in this project.")
    parser.add_argument("--gt_path", default='../PascalVOC2012/VOCdevkit/VOC2012/SegmentationClass', type=str)
    parser.add_argument('--session_name', default="exp", type=str)
    args = parser.parse_args()

    if args.dataset == 'voc':
        dataset_root = '../PascalVOC2012/VOCdevkit/VOC2012'
        num_cls = 21  # 20 foreground classes + background
        dataset = data_voc.VOC12ImageDataset('data/train_' + args.dataset + '.txt', voc12_root=dataset_root, img_normal=None, to_torch=False)
    else:  # 'coco' — guaranteed by the `choices` constraint above
        args.gt_path = "../ms_coco_14&15/SegmentationClass/train2014/"
        dataset_root = "../ms_coco_14&15/images"
        num_cls = 81  # 80 foreground classes + background
        dataset = data_coco.COCOImageDataset('data/train_' + args.dataset + '.txt', coco_root=dataset_root, img_normal=None, to_torch=False)

    # Evaluation must cover every image. With batch_size=1 no incomplete final
    # batch can exist, so drop_last=False is behavior-identical to the old
    # drop_last=True while stating the intent correctly.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
    pyutils.Logger(os.path.join(args.session_name, 'eval_' + args.session_name + '.log'))
    run(args, args.session_name + "/npy", num_cls)