-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
83 lines (65 loc) · 4.15 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sys
from dataset import ShapeNet
from loss import Loss
from config import Config
from model import ExtrudeNet
from utils import generate_mesh
import argparse
def eval(config):
    """Evaluate a trained ExtrudeNet checkpoint on the validation split.

    Loads ``./checkpoints/<experiment_name>/models/model.th``, runs the model
    over the ``"val"`` partition of ``config.category``, calls
    ``generate_mesh`` once per batch, and prints the per-batch-averaged
    losses and occupancy metrics.

    The function name shadows the builtin ``eval``; it is kept unchanged for
    backward compatibility with existing callers.

    Args:
        config: experiment configuration. Reads ``dataset_root``,
            ``num_sample_points``, ``category``, ``test_batch_size_per_gpu``,
            ``num_gpu`` and ``experiment_name``; the whole object is also
            passed to ``ExtrudeNet``, ``Loss`` and ``generate_mesh``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the loader yields no batches (empty dataset, or
            fewer samples than one batch while ``drop_last=True``).
    """
    test_loader = DataLoader(
        ShapeNet(
            shapenet_root=config.dataset_root,
            num_testing_points=config.num_sample_points,
            balance=True,
            categories=[config.category],
            partition="val",
        ),
        pin_memory=True,
        num_workers=24,
        batch_size=config.test_batch_size_per_gpu * config.num_gpu,
        shuffle=False,
        # Trailing samples that do not fill a whole batch are not evaluated.
        drop_last=True,
    )
    device = torch.device("cuda")
    model = ExtrudeNet(config).to(device)

    pre_train_model_path = './checkpoints/%s/models/model.th' % config.experiment_name
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(pre_train_model_path):
        raise FileNotFoundError(
            "Cannot find pre-train model for experiment: {}\nNo such a file: {}".format(
                config.experiment_name, pre_train_model_path))
    # Reuse the validated path and map saved tensors onto the evaluation
    # device regardless of which device they were saved from.
    model.load_state_dict(torch.load(pre_train_model_path, map_location=device))
    # NOTE(review): the model is not wrapped in DataParallel, so only one
    # device is used even though the batch size is scaled by config.num_gpu.
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    criterion = Loss(config)
    model.eval()

    start_time = time.time()
    test_iter = 0
    sum_loss_recon = sum_loss_primitive = sum_loss_total = 0.0
    sum_accuracy = sum_recall = 0.0
    with torch.no_grad():
        for surface_pointcloud, testing_points in tqdm(test_loader):
            surface_pointcloud = surface_pointcloud.to(device)
            testing_points = testing_points.to(device)
            point_cloud = surface_pointcloud.transpose(2, 1)
            # The first three channels of testing_points are fed to the model
            # as query points; the last channel is the occupancy target.
            # NOTE(review): is_training=True is kept from the original;
            # presumably required so the model returns the auxiliary outputs
            # consumed by the loss — confirm in model.py.
            occupancies, primitive_sdfs, primitive_parameters, support_distances = model(
                point_cloud, testing_points[:, :, :3], is_training=True)
            target = testing_points[:, :, -1]
            loss_dict = criterion(occupancies, target, primitive_sdfs,
                                  primitive_parameters, support_distances)

            predict_occupancies = (occupancies >= 0.5).float()
            target_occupancies = (target >= 0.5).float()
            true_positives = torch.sum(predict_occupancies * target_occupancies)
            # NOTE: "accuracy" here is TP / #target-occupied (conventionally
            # recall) and "recall" is TP / #predicted-occupied (conventionally
            # precision). The F-score is symmetric, so it is unaffected; the
            # names and printed labels are kept for log compatibility.
            # The epsilon avoids a NaN when a batch has no occupied targets.
            accuracy = true_positives / (torch.sum(target_occupancies) + 1e-9)
            recall = true_positives / (torch.sum(predict_occupancies) + 1e-9)

            sum_loss_recon += loss_dict["loss_recon"].item()
            sum_loss_primitive += loss_dict["loss_primitive"].item()
            sum_loss_total += loss_dict["loss_total"].item()
            sum_accuracy += accuracy.item()
            sum_recall += recall.item()

            generate_mesh(model, point_cloud, config, test_iter)
            test_iter += 1

    if test_iter == 0:
        raise RuntimeError(
            "No evaluation batches: the validation loader is empty or smaller "
            "than one batch (drop_last=True).")

    # Each average is divided by the batch count exactly once (the original
    # divided the reconstruction loss twice).
    avg_test_loss = sum_loss_total / test_iter
    avg_test_loss_recon = sum_loss_recon / test_iter
    avg_test_loss_primitive = sum_loss_primitive / test_iter
    test_accuracy = sum_accuracy / test_iter
    test_recall = sum_recall / test_iter
    test_fscore = 2 * test_accuracy * test_recall / (test_accuracy + test_recall + 1e-6)
    print("Evaluating: time: %4.4f, loss_total: %.6f, loss_recon: %.6f, loss_primitive: %.6f, acc: %.6f, recall: %.6f, fscore: %.6f" % (
        time.time() - start_time,
        avg_test_loss,
        avg_test_loss_recon,
        avg_test_loss_primitive,
        test_accuracy,
        test_recall,
        test_fscore))
if __name__ == "__main__":
    # Command-line entry point: evaluate the experiment described by the
    # given config file.
    parser = argparse.ArgumentParser(description='ExtrudeNet')
    parser.add_argument('--config_path', type=str, default='./configs/plane.json', metavar='PATH',
                        help='path to the experiment config file (e.g. ./configs/plane.json)')
    args = parser.parse_args()
    config = Config(args.config_path)
    eval(config)