# test_net.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Testing script for cvpods.
This script reads a given config file and runs evaluation on the selected checkpoint(s).
It is an entry point that is made to test standard models in cvpods.
In order to let one script support testing of many models,
this script contains logic that is specific to these built-in models and therefore
may not be suitable for your own project.
For example, your research project perhaps only needs a single "evaluator".
Therefore, we recommend that you use cvpods as a library and take
this file as an example of how to use the library.
You may want to write your own script with your datasets and other customizations.
"""
import glob
import logging
import os
import re
import sys
from collections import OrderedDict
from pprint import pformat
from torch.nn.parallel import DistributedDataParallel
from cvpods.checkpoint import DefaultCheckpointer
from cvpods.engine import RUNNERS, default_argument_parser, default_setup, launch
from cvpods.evaluation import build_evaluator, verify_results
from cvpods.modeling import GeneralizedRCNN, GeneralizedRCNNWithTTA, TTAWarper
from cvpods.utils import PathManager, comm
sys.path.insert(0, '.')
from config import config # noqa: E402
from net import build_model # noqa: E402
sys.path.insert(0, '.')
def runner_decrator(cls):
    """
    Patch a runner class with custom evaluator-building and TTA-testing hooks.

    Two classmethods, ``build_evaluator`` and ``test_with_TTA``, are assigned
    onto ``cls`` (overwriting attributes of the same name if present) and the
    very same class object is returned.

    Args:
        cls: the runner class to patch. ``test_with_TTA`` calls ``cls.test``,
            so the class must provide it.

    Returns:
        ``cls`` itself, modified in place.
    """

    def custom_build_evaluator(cls, cfg, dataset_name, dataset, output_folder=None):
        """
        Create the evaluator for one dataset.

        Delegates to ``build_evaluator`` and passes ``config.GLOBAL.DUMP_TEST``
        as its ``dump`` argument. The flag is read from the module-level
        ``config`` object, not from the ``cfg`` argument.
        """
        return build_evaluator(
            cfg, dataset_name, dataset, output_folder, dump=config.GLOBAL.DUMP_TEST
        )

    def custom_test_with_TTA(cls, cfg, model):
        """
        Run ``cls.test`` on ``model`` wrapped for test-time augmentation (TTA).

        Results are written under ``<cfg.OUTPUT_DIR>/inference_TTA`` and every
        key of the returned mapping gets a ``"_TTA"`` suffix.
        """
        tta_logger = logging.getLogger("cvpods.runner")
        tta_logger.info("Running inference with test-time augmentation ...")

        # Look through a DistributedDataParallel wrapper only to pick the TTA
        # wrapper class; the wrapper itself still receives ``model`` as given.
        unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
        tta_wrapper = GeneralizedRCNNWithTTA if isinstance(unwrapped, GeneralizedRCNN) else TTAWarper
        tta_model = tta_wrapper(cfg, model)

        tta_output_dir = os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
        results = cls.test(cfg, tta_model, output_folder=tta_output_dir)
        return OrderedDict((name + "_TTA", value) for name, value in results.items())

    for hook_name, hook in (
        ("build_evaluator", custom_build_evaluator),
        ("test_with_TTA", custom_test_with_TTA),
    ):
        setattr(cls, hook_name, classmethod(hook))
    return cls
def test_argument_parser():
    """
    Build the command-line parser used by this testing script.

    Starts from ``default_argument_parser()`` and registers four extra
    options: ``--start-iter``, ``--end-iter``, ``--debug`` and ``--double-rpn``.

    Returns:
        The parser object returned by ``default_argument_parser()`` with the
        extra options added.
    """
    parser = default_argument_parser()

    # Integer bounds of the checkpoint range; both are None when not given.
    for flag, description in (
        ("--start-iter", "start iter used to test"),
        ("--end-iter", "end iter used to test"),
    ):
        parser.add_argument(flag, type=int, default=None, help=description)

    parser.add_argument("--debug", action="store_true", help="use debug mode or not")

    # NOTE: this is a ``store_false`` flag, so ``args.double_rpn`` is True by
    # default and becomes False only when ``--double-rpn`` is passed.
    parser.add_argument("--double-rpn", action="store_false", default=True)
    return parser
def filter_by_iters(file_list, start_iter, end_iter):
    """
    Pick the checkpoints whose iteration tag lies inside ``(start_iter, end_iter)``.

    ``file_list`` is sorted **in place** by modification time, oldest first.
    The iteration tag of a file is the text between ``model_`` and ``.pth`` in
    its path, i.e. a number or the literal ``"final"``. The list is scanned
    from the newest file backwards and the scan stops

    * at a ``"final"`` checkpoint that is not the newest file (it and every
      older file are ignored), or
    * as soon as the next-older file carries a larger iteration number than
      the current one.

    Both conditions presumably mark files left over from an earlier run.

    Args:
        file_list (list[str]): checkpoint paths. If the first path starts with
            ``"s3://"`` the sort key is ``PathManager.stat(path).m_date``,
            otherwise ``os.path.getmtime``.
        start_iter (int | None): exclusive lower bound; treated as 0 when only
            ``end_iter`` is given.
        end_iter (int | None): exclusive upper bound; unbounded when only
            ``start_iter`` is given. A ``"final"`` checkpoint is selected only
            when the upper bound is unbounded.

    Returns:
        list[str]: the selected paths in modification-time order. When both
        bounds are None, only the most recently modified file is returned.
        An empty ``file_list`` yields an empty list.

    Raises:
        ValueError: if a scanned file has a tag that is neither a number nor
            ``"final"``.
    """
    # Nothing to sort or inspect; let the caller decide how to report it.
    if not file_list:
        return []

    # sort file_list by modified time
    if file_list[0].startswith("s3://"):
        file_list.sort(key=lambda x: PathManager.stat(x).m_date)
    else:
        file_list.sort(key=os.path.getmtime)

    # use latest ckpt if start_iter and end_iter are not given
    if start_iter is None and end_iter is None:
        return [file_list[-1]]

    unbounded = float("inf")
    lower = 0.0 if start_iter is None else float(start_iter)
    upper = unbounded if end_iter is None else float(end_iter)

    def iter_value(tag):
        # "final" has no number; rank it above every numbered checkpoint.
        return unbounded if tag == "final" else float(tag)

    iter_infos = [re.split(r"model_|\.pth", f)[-2] for f in file_list]

    # Index of the last entry that must NOT be scanned. -1 means "scan down to
    # index 0"; the previous exclusive bound of 0 silently dropped the oldest
    # checkpoint even when it was inside the requested range.
    boundary = -1
    if "final" in iter_infos and iter_infos[-1] != "final":
        boundary = iter_infos.index("final")

    keep = [False] * len(iter_infos)
    for i in range(len(iter_infos) - 1, boundary, -1):
        tag = iter_infos[i]
        if tag == "final":
            keep[i] = upper == unbounded
        else:
            keep[i] = lower < float(tag) < upper

        # Stop once the next-older file has a larger iteration number. The
        # check is skipped for the excluded boundary entry (which may be
        # "final") and for index -1, and it runs whether or not the current
        # entry was in range so the scan never walks past the boundary.
        if i - 1 > boundary and iter_value(iter_infos[i - 1]) > iter_value(tag):
            break

    return [filename for filename, selected in zip(file_list, keep) if selected]
def get_valid_files(args, cfg, logger):
    """
    Resolve the list of checkpoint files to evaluate.

    Resolution order:

    1. If ``"MODEL.WEIGHTS"`` appears in ``args.opts``, return just
       ``cfg.MODEL.WEIGHTS`` (asserting that ``PathManager.exists`` accepts it).
    2. Otherwise glob ``model_*.pth`` inside ``cfg.OUTPUT_DIR``.
    3. If nothing is found locally, list the remote directory built from
       ``cfg.OSS.DUMP_PREFIX`` and the part of ``cfg.OUTPUT_DIR`` that follows
       ``"cvpods_playground"``, keeping entries whose name matches
       ``model_.+\\.pth``.

    The candidates from step 2 or 3 are then narrowed by ``filter_by_iters``
    using ``args.start_iter`` and ``args.end_iter``.

    Args:
        args: parsed command-line arguments; ``opts``, ``start_iter`` and
            ``end_iter`` are read.
        cfg: config object; ``MODEL.WEIGHTS``, ``OUTPUT_DIR`` and
            ``OSS.DUMP_PREFIX`` are read.
        logger: logger used for the fallback warning and the final summary.

    Returns:
        list: paths of the checkpoints to test, in test order.

    Raises:
        AssertionError: if the explicit weights do not exist, if the remote
            listing has no matching file, or if filtering leaves nothing.
    """
    # An explicit weights override on the command line short-circuits discovery.
    if "MODEL.WEIGHTS" in args.opts:
        weights_path = cfg.MODEL.WEIGHTS
        assert PathManager.exists(weights_path), "{} not exist!!!".format(weights_path)
        return [weights_path]

    candidates = glob.glob(os.path.join(cfg.OUTPUT_DIR, "model_*.pth"))

    if not candidates:  # local file invalid, get it from oss
        # Drop everything up to "cvpods_playground" plus the one character
        # that follows it (presumably the path separator).
        relative_dir = cfg.OUTPUT_DIR.split("cvpods_playground")[-1][1:]
        remote_file_path = os.path.join(cfg.OSS.DUMP_PREFIX, relative_dir)
        logger.warning(
            "No checkpoint file was found locally, try to "
            f"load the corresponding dump file on OSS site: {remote_file_path}."
        )
        for entry in PathManager.ls(remote_file_path):
            if re.match(r"model_.+\.pth", entry.name) is not None:
                candidates.append(str(entry))
        assert candidates, "No valid file found on OSS"

    selected = filter_by_iters(candidates, args.start_iter, args.end_iter)
    assert selected, "No checkpoint valid in {}.".format(cfg.OUTPUT_DIR)
    logger.info("All files below will be tested in order:\n{}".format(pformat(selected)))
    return selected
def main(args):
    """
    Evaluate every selected checkpoint with the configured runner.

    Steps:

    1. Merge ``args.opts`` into the module-level ``config`` and run
       ``default_setup`` to obtain ``cfg`` and a logger.
    2. In debug mode, set ``cfg.SOLVER.IMS_PER_BATCH`` to
       ``IMS_PER_DEVICE * num_gpus`` if it differs.
    3. For each file from ``get_valid_files``: build the model, load the
       weights, run ``test`` (or ``test_with_TTA`` when
       ``cfg.TEST.AUG.ENABLED`` is set) on the patched runner class, and call
       ``verify_results`` on the main process.

    Args:
        args: parsed command-line arguments from ``test_argument_parser``.
    """
    config.merge_from_list(args.opts)
    cfg, logger = default_setup(config, args)

    if args.debug:
        debug_batch = int(cfg.SOLVER.IMS_PER_DEVICE * args.num_gpus)
        if cfg.SOLVER.IMS_PER_BATCH != debug_batch:
            cfg.SOLVER.IMS_PER_BATCH = debug_batch
            logger.warning("SOLVER.IMS_PER_BATCH is changed to {}".format(debug_batch))

    for checkpoint_path in get_valid_files(args, cfg, logger):
        cfg.MODEL.WEIGHTS = checkpoint_path
        model = build_model(cfg, args.double_rpn)

        checkpointer = DefaultCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)

        # The runner class is looked up and patched for every checkpoint,
        # after the TTA switch has been read.
        use_tta = cfg.TEST.AUG.ENABLED
        runner_cls = runner_decrator(RUNNERS.get(cfg.TRAINER.NAME))
        evaluate = runner_cls.test_with_TTA if use_tta else runner_cls.test
        res = evaluate(cfg, model)

        if comm.is_main_process():
            verify_results(cfg, res)
if __name__ == "__main__":
    # Script entry point: parse the command line, echo it, then hand ``main``
    # and the distributed settings over to ``launch``.
    args = test_argument_parser().parse_args()
    print("Command Line Args:", args)

    launch_options = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "args": (args,),  # positional arguments forwarded to ``main``
    }
    launch(main, args.num_gpus, **launch_options)