forked from TianzhongSong/keras-FP16-test
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_pointnet
27 lines (22 loc) · 949 Bytes
/
eval_pointnet
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from models.pointnet_cls import PointNet
from utils.point_cloud_data_generator import DataGenerator
from keras import backend as K
import argparse
# Batch size used both by the data generator and to derive the step count.
BATCH_SIZE = 32
# NOTE(review): presumably the number of samples in ply_data_test_cls.h5
# (ModelNet40 test split) — confirm against the data file.
NB_TEST_SAMPLES = 2468


def evaluate_cls():
    """Evaluate the pretrained PointNet classifier on the test data.

    Builds a 40-class ``PointNet``, compiles it, loads
    ``./weights/pointnet_cls.h5`` by layer name and runs
    ``evaluate_generator`` over ``./data/ply_data_test_cls.h5``.

    Returns:
        The result of ``model.evaluate_generator``. With the model compiled
        with ``metrics=['accuracy']``, index 0 is the loss and index 1 the
        accuracy.
    """
    nb_classes = 40
    data = './data/ply_data_test_cls.h5'
    model = PointNet(nb_classes)
    # compile() is needed before evaluation; the optimizer is never stepped
    # here because this script only evaluates.
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.load_weights('./weights/pointnet_cls.h5', by_name=True)
    # NOTE(review): the trailing False is presumably a shuffle/augmentation
    # flag of DataGenerator — confirm against its definition.
    dg = DataGenerator(data, BATCH_SIZE, nb_classes, False)
    # Floor division: 2468 // 32 == 77 steps == 2464 samples, so the last
    # 2468 % 32 == 4 samples are not evaluated (assuming the generator yields
    # fixed-size batches — TODO confirm before switching to a ceiling).
    return model.evaluate_generator(dg.generator(),
                                    steps=NB_TEST_SAMPLES // BATCH_SIZE)


# Evaluators selectable via --model. Only classification is implemented;
# adding a 'seg' entry here makes it selectable on the command line.
EVALUATORS = {'cls': evaluate_cls}


def main(argv=None):
    """Parse command-line options and run the selected evaluation.

    Prints the loss and then the accuracy, one per line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: (via argparse) when ``--model`` names an evaluator that
            is not implemented, or when ``--help`` is requested.
    """
    parse = argparse.ArgumentParser()
    parse.add_argument('--model', type=str, default='cls',
                       choices=sorted(EVALUATORS),
                       help="model to evaluate; only cls is implemented")
    parse.add_argument('--dtype', type=str, default='float32')
    args = parse.parse_args(argv)
    # Set the backend default float type before the model is built, so the
    # newly created layers/weights use it.
    K.set_floatx(args.dtype)
    score = EVALUATORS[args.model]()
    print(score[0])
    print(score[1])


if __name__ == '__main__':
    main()