-
Notifications
You must be signed in to change notification settings - Fork 186
/
Copy patheval.py
72 lines (56 loc) · 2.31 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import cv2 as cv
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
from config import device
from data_gen import data_transforms
from utils import ensure_folder
# Folder of low-resolution alphamatting input images (only *.png files are used).
IMG_FOLDER = 'data/alphamatting/input_lowres'
# The two lists below are index-aligned: mattes predicted from the trimaps in
# TRIMAP_FOLDERS[k] are written to OUTPUT_FOLDERS[k] under the same file name.
TRIMAP_FOLDERS = ['data/alphamatting/trimap_lowres/Trimap{}'.format(k) for k in range(1, 4)]
OUTPUT_FOLDERS = ['images/alphamatting/output_lowres/Trimap{}'.format(k) for k in range(1, 4)]
def alpha_to_uint8(pred, trimap):
    """Convert a predicted alpha map into an 8-bit matte that honours the trimap.

    Pixels the trimap marks 0 are forced to alpha 0 and pixels it marks 255
    are forced to alpha 1; all other pixels keep the prediction. The
    prediction is clipped to [0, 1] first so that values slightly outside
    that range cannot wrap around when cast to uint8.

    Args:
        pred: float array with the same shape as ``trimap``. Not modified.
        trimap: uint8 array; 0 and 255 mark known regions.

    Returns:
        uint8 array of the same shape, ``alpha * 255`` truncated toward zero
        (the same truncation the previous inline code performed).
    """
    # np.clip returns a new array, so the in-place edits below never touch pred.
    alpha = np.clip(np.asarray(pred, dtype=np.float32), 0.0, 1.0)
    alpha[trimap == 0] = 0.0
    alpha[trimap == 255] = 1.0
    return (alpha * 255).astype(np.uint8)


def main():
    """Predict a matte for every input image under each trimap set.

    Loads the model from 'BEST_checkpoint.tar', then for every *.png in
    IMG_FOLDER and every index-aligned pair (TRIMAP_FOLDERS[k],
    OUTPUT_FOLDERS[k]) feeds a 4-channel tensor (transformed RGB image +
    trimap / 255) to the model and writes the resulting matte to
    OUTPUT_FOLDERS[k] under the same file name. Unreadable images, missing
    trimaps and trimaps whose size differs from the image are skipped with
    a message instead of aborting the run.
    """
    # map_location lets a checkpoint that was saved on a GPU be loaded when
    # `device` is the CPU (and vice versa).
    checkpoint = torch.load('BEST_checkpoint.tar', map_location=device)
    model = checkpoint['model']
    # The stored model exposes the real network as `.module` (presumably an
    # nn.DataParallel wrapper); fall back to the object itself if unwrapped.
    model = getattr(model, 'module', model)
    model = model.to(device)
    model.eval()

    transformer = data_transforms['valid']

    # makedirs also creates the missing parents ('images', 'images/alphamatting',
    # '.../output_lowres').
    for folder in OUTPUT_FOLDERS:
        os.makedirs(folder, exist_ok=True)

    files = sorted(f for f in os.listdir(IMG_FOLDER) if f.endswith('.png'))

    with torch.no_grad():
        for file in tqdm(files):
            img = cv.imread(os.path.join(IMG_FOLDER, file))
            if img is None:
                tqdm.write('skipping {}: could not read image'.format(file))
                continue
            h, w = img.shape[:2]

            # Channels 0-2: transformed RGB image; channel 3: trimap scaled to [0, 1].
            # Assumes transformer returns a (3, h, w) tensor -- TODO confirm in data_gen.
            x = torch.zeros((1, 4, h, w), dtype=torch.float)
            rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
            x[0, 0:3, :, :] = transformer(transforms.ToPILImage()(rgb))

            for trimap_folder, output_folder in zip(TRIMAP_FOLDERS, OUTPUT_FOLDERS):
                filename = os.path.join(trimap_folder, file)
                tqdm.write('reading {}...'.format(filename))
                trimap = cv.imread(filename, cv.IMREAD_GRAYSCALE)
                if trimap is None or trimap.shape != (h, w):
                    tqdm.write('skipping {}: missing trimap or size mismatch'.format(filename))
                    continue

                # Only the trimap channel changes between trimap sets; the CPU
                # tensor `x` is reused and a fresh device copy is made per pass.
                x[0, 3, :, :] = torch.from_numpy(trimap / 255.)
                pred = model(x.to(device)).cpu().numpy().reshape((h, w))

                out = alpha_to_uint8(pred, trimap)
                filename = os.path.join(output_folder, file)
                cv.imwrite(filename, out)
                tqdm.write('wrote {}.'.format(filename))


if __name__ == '__main__':
    main()