-
Notifications
You must be signed in to change notification settings - Fork 0
/
calc_dist_feats.py
106 lines (88 loc) · 4.1 KB
/
calc_dist_feats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import glob
import os
import warnings
from typing import Optional

import numpy as np
import pandas as pd
import PIL
import PIL.Image
import torch
import tqdm
from piq import FID, KID
from piq.feature_extractors import InceptionV3
from torchvision import transforms as T
@torch.no_grad()
def compute_feats(
        filenames,
        feature_extractor: Optional[torch.nn.Module] = None,
        device: str = 'cuda') -> torch.Tensor:
    r"""Compute one flattened feature vector per readable image file.

    Every file is opened with PIL, converted to RGB, turned into a float
    ``(1, 3, H, W)`` tensor by ``torchvision.transforms.ToTensor`` and fed to
    ``feature_extractor`` one image at a time (gradients disabled). Files that
    PIL cannot identify are skipped with a ``warnings.warn``.

    Args:
        filenames: Iterable of image file paths.
        feature_extractor: Model used to generate image features. If ``None``,
            a new ``InceptionV3`` from ``piq`` is used. The model must return a
            sized sequence holding the features of exactly one layer.
        device: Device on which to run the model.

    Returns:
        Tensor with one row per successfully read image, in input order; each
        row is the flattened output of ``feature_extractor`` for that image.

    Raises:
        TypeError: If ``feature_extractor`` is not a ``torch.nn.Module``.
        ValueError: If the model returns features from more than one layer, or
            if none of ``filenames`` could be read.
    """
    if feature_extractor is None:
        print('WARNING: default feature extractor (InceptionNet V3) is used.')
        feature_extractor = InceptionV3()
    elif not isinstance(feature_extractor, torch.nn.Module):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise TypeError(
            f"Feature extractor must be PyTorch module. Got {type(feature_extractor)}")
    feature_extractor.to(device)
    feature_extractor.eval()

    to_tensor = T.ToTensor()
    total_feats = []
    for filename in tqdm.tqdm(filenames):
        try:
            # ``with`` closes the file handle PIL keeps open after a lazy open().
            with PIL.Image.open(filename) as img:
                # Palette, grayscale and RGBA PNGs would otherwise yield raw
                # palette indices or a channel count the model cannot consume.
                image = to_tensor(img.convert("RGB")).unsqueeze(0)
        except PIL.UnidentifiedImageError:
            warnings.warn(f"Ignoring image {filename}, could not be read")
            continue
        image = image.float().to(device)
        features = feature_extractor(image)
        if len(features) != 1:
            raise ValueError(
                f"feature_encoder must return list with features from one layer. Got {len(features)}")
        # ``reshape`` (unlike ``view``) also works for non-contiguous outputs.
        total_feats.append(features[0].reshape(1, -1))

    if not total_feats:
        # torch.cat would otherwise fail with an opaque "non-empty list" error.
        raise ValueError("None of the given files could be read as an image; no features computed.")
    return torch.cat(total_feats, dim=0)
def get_or_create_feats(filelist, output_file):
    """Return features cached in *output_file*, computing and caching them if absent.

    If ``output_file`` already exists it is loaded with ``np.load`` and returned
    as-is. Otherwise features for ``filelist`` are computed with a fresh
    ``InceptionV3`` via ``compute_feats``, written to exactly ``output_file``
    and returned.

    NOTE(review): the cache is keyed only on the path. It is not invalidated
    when ``filelist`` changes, so delete ``output_file`` after changing the
    image set.

    Args:
        filelist: Image paths to featurize on a cache miss.
        output_file: Path of the ``.npy`` cache file.

    Returns:
        ``numpy.ndarray`` holding the loaded or freshly computed features.
    """
    if os.path.isfile(output_file):
        return np.load(output_file)

    print(f"Processing for {output_file}")
    feats = compute_feats(filelist, InceptionV3()).cpu().numpy()
    # Save through a file handle so np.save cannot append a ".npy" suffix
    # (which would make the isfile() check above miss on every run), and write
    # to a temporary name first so an interrupted run never leaves a truncated
    # cache that later runs would load as valid.
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "wb") as fh:
        np.save(fh, feats)
    os.replace(tmp_file, output_file)
    return feats
def _sorted_pngs(directory):
    """Return the sorted paths of the ``*.png`` files directly inside *directory*."""
    return sorted(glob.glob(os.path.join(directory, "*.png")))


def _matching_png_lists(gt_dir, enh_dir):
    """Return ``(gt_filelist, enh_filelist)``, requiring identical, non-empty PNG basenames.

    Raises:
        ValueError: If the two directories do not hold the same PNG basenames,
            or if neither holds any PNG file.
    """
    gt_filelist = _sorted_pngs(gt_dir)
    enh_filelist = _sorted_pngs(enh_dir)
    gt_basenames = [os.path.basename(f) for f in gt_filelist]
    enh_basenames = [os.path.basename(f) for f in enh_filelist]
    if gt_basenames != enh_basenames:
        # Sorted so the error message is deterministic across runs.
        diff = sorted(set(gt_basenames) ^ set(enh_basenames))
        raise ValueError(f"File list in {enh_dir} does not match file list in {gt_dir}, stopping! Diff (XOR): {diff}")
    if not gt_filelist:
        # Both lists are empty here; fail before any model is built.
        raise ValueError(f"No *.png files found in {gt_dir} or {enh_dir}, stopping!")
    return gt_filelist, enh_filelist


@torch.no_grad()
def main(argv=None):
    """Compute FID and KID between a ground-truth and an enhanced image directory.

    Both directories must contain the same set of ``*.png`` basenames. For each
    directory, features are loaded from, or computed and cached to,
    ``<dir>/feats.npy``. The scores are printed and written to
    ``<enh_dir>/scores.csv``.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Raises:
        ValueError: If the PNG file lists of the two directories differ or are empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt-dir", type=str, help="Path to ground truth directory", required=True)
    parser.add_argument("--enh-dir", type=str, help="Path to corrupted/enhanced directory", required=True)
    parser.add_argument("--cuda-device", type=int, help="CUDA device to use", default=0)
    args = parser.parse_args(argv)
    gt_dir, enh_dir = args.gt_dir, args.enh_dir

    # Validate the inputs before selecting a CUDA device or building a model.
    gt_filelist, enh_filelist = _matching_png_lists(gt_dir, enh_dir)
    gt_feats_file = os.path.join(gt_dir, "feats.npy")
    enh_feats_file = os.path.join(enh_dir, "feats.npy")
    scores_file = os.path.join(enh_dir, "scores.csv")

    with torch.cuda.device(args.cuda_device):
        print(f"Loading or calculating ground-truth features from dir {gt_dir}...")
        gt_feats = torch.from_numpy(get_or_create_feats(gt_filelist, gt_feats_file))
        print(f"Calculating features for dirs {enh_dir}, using {gt_dir} as ground truth dir. "
              f"Will save feats to <dir>/feats.npy, and save KID/FID scores to <dir>/scores.csv.")
        enh_feats = torch.from_numpy(get_or_create_feats(enh_filelist, enh_feats_file))
        fid = FID().compute_metric(gt_feats, enh_feats).item()
        kid = KID().compute_metric(gt_feats, enh_feats).item()

    print(f"FID: {fid}, KID: {kid}")
    pd.DataFrame({"FID": [fid], "KID": [kid]}).to_csv(scores_file, index=False)
    print("===================================== Done!")
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()