[WIP] Regression Benchmark #146
Open
nbei wants to merge 10 commits into open-mmlab:master from nbei:regress
Commits
5cbdd86  initial regress script (nbei)
716da33  revise running command (nbei)
e85c7b4  fix bug in ckpt path (nbei)
0191057  fix bug in ckpt path (nbei)
a878cf6  update command (nbei)
3381547  add multiple test (nbei)
1fa83fe  fix bug in styleganv1 256 (nbei)
e03fcbb  update parse_test_log.py (nbei)
b0d871e  update PR string format (nbei)
013d35b  Merge branch 'master' of github.com:open-mmlab/mmgeneration into regress (nbei)
parse_test_log.py
@@ -0,0 +1,278 @@
import argparse
import json
import os

import mmcv
import yaml

# --------------------------------------------------------------------------- #
#
# Tips for using the test_benchmark.sh and parse_test_log.py:
#
# 1. The directory for pre-trained checkpoints should follow the structure
#    of the ``./configs`` folder, especially the name of each method.
# 2. For the regression report, you can quickly check the ``regress_status``:
#    a) ``Missing Eval Info``: Cannot read the information table from the log
#       file correctly. Please further check the log file and testing output.
#    b) ``Failed``: The evaluation metric cannot match the value in our
#       metafile.
#    c) ``Pass``: Successfully passed the regression test.
# 3. We should check the representation of the result value in the metafile
#    and the generated log table. (Related to the convert_str_metric_value()
#    function.)
# 4. The metric name should be mapped to the standard one. Please check
#    ``metric_name_mapping`` to ensure the metric name is among the keys of
#    the dict.
# 5. Pay attention to the format of numerical values. For instance, the
#    precision value can be 0.69xx or 69.xx.
#
# --------------------------------------------------------------------------- #
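
# For example, the evaluation logs may be organized as below (hypothetical
# paths that follow the ``./configs`` naming convention):
#
#   <logdir>/styleganv2/stylegan2_c2_ffhq_256_b4x8_20210407_160709-7890ae1f_eval_log.txt
#   <logdir>/styleganv1/styleganv1_ffhq_256_g8_25Mimg_20210407_161748-0094da86_eval_log.txt
#
# ``get_log_files()`` below maps each ``*_eval_log.txt`` to the checkpoint
# ``*.pth`` with the same stem.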

# Used to store log info: the key indicates the path of each log file, while
# the value contains meta information (method category, ckpt name) and the
# parsed log info.
data_dict = {}

# Map metric names from metafiles and log tables to a canonical name.
metric_name_mapping = {
    'FID50k': 'FID',
    'FID': 'FID',
    'P&R': 'PR',
    'P&R50k': 'PR',
    'PR': 'PR',
    'PR50k': 'PR'
}

# Maximum allowed absolute difference between an evaluated metric value and
# the reference value recorded in the metafile.
tolerance = 2.0


def parse_args():
    parser = argparse.ArgumentParser(
        description='Parse the regression benchmark evaluation logs')
    parser.add_argument('logdir', help='evaluation log path')
    parser.add_argument(
        '--out',
        type=str,
        default='./work_dirs/',
        help='output directory for benchmark information')

    args = parser.parse_args()
    return args


def read_metafile(filepath):
    with open(filepath, 'r') as f:
        data = yaml.safe_load(f)

    return data


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        pass

    return False


def read_table_from_log(filepath):
    """Read the evaluation table from the log file.

    .. note:: We assume the log file only records one evaluation procedure.
    """
    # table_info will contain 2 elements. The first element is the head of the
    # table indicating the meaning of each column. The second element is the
    # content of each column, e.g., training configuration, and fid value.
    table_info = []

    with open(filepath, 'r') as f:
        for line in f.readlines():
            if 'mmgen' in line or 'INFO' in line:
                # running message
                continue
            if line[0] == '|' and '+' in line:
                # table split line
                continue
            if line[0] == '|':
                # useful content
                line = line.strip()
                line_split = line.split('|')
                line_split = [x.strip() for x in line_split]
                table_info.append(line_split)

    if len(table_info) < 2:
        print(f'Please check the log file: {filepath}. Cannot get the eval'
              ' information correctly.')
    elif len(table_info) > 2:
        print(f'Got more than 2 rows from the eval table in {filepath}')
    else:
        assert len(table_info[0]) == len(
            table_info[1]), 'The eval table cannot be aligned'

    return table_info


def convert_str_metric_value(value):
    if isinstance(value, float):
        return value

    if isinstance(value, int):
        return float(value)

    # otherwise, we assume the value is in string format

    # Case: 3.42 (1.xxx/2.xxx) -- used in FID
    if '(' in value and ')' in value:
        split = value.split('(')
        res = split[0].strip()
        return float(res)

    # Case: 60.xx/40.xx -- used in PR metric
    elif '/' in value:
        split = value.split('/')

        return [float(x.strip()) for x in split]
    # Case: precision: 69.59999918937683, recall:40.200000643730164
    elif ',' in value:
        split = [x.strip() for x in value.split(',')]
        res = []
        for x in split:
            if ':' in x:
                num_str = x.split(':')[1].strip()
                res.append(float(num_str))
            elif is_number(x):
                res.append(float(x))
        return res
    else:
        try:
            res = float(value)
            return res
        except Exception as err:
            print(f'Cannot convert str value {value} to float')
            print(f'Unexpected {err}, {type(err)}')
            raise err
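
# Illustrative conversions for the cases above (the numbers are made up):
#   convert_str_metric_value('3.42 (1.23/2.19)')  -> 3.42
#   convert_str_metric_value('69.8/40.2')         -> [69.8, 40.2]
#   convert_str_metric_value('precision: 69.6, recall: 40.2')
#       -> [69.6, 40.2]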


def compare_eval_orig_info(orig_info, eval_info):
    flag = True
    for k, v in eval_info.items():
        orig_value = orig_info[k]
        if isinstance(v, float):
            if abs(v - orig_value) > tolerance:
                print(v, orig_value)
                flag = False
                break
        elif isinstance(v, list):
            for tmp_v, tmp_orig in zip(v, orig_value):
                if abs(tmp_v - tmp_orig) > tolerance:
                    print(v, orig_value)
                    flag = False
                    break
            if not flag:
                break
        else:
            raise RuntimeError(f'Cannot compare eval_value: {v} and '
                               f'orig_value: {orig_value}.')

    return flag


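# ``check_info_from_metafile`` below assumes each metafile entry roughly
# follows the YAML sketch here (inferred from the field accesses in this
# script, not copied from a real metafile; the values are made up):
#
#   Models:
#     - Name: stylegan2_c2_ffhq_256_b4x8_800k
#       Weights: https://download.example.com/stylegan2_c2_ffhq_256_b4x8.pth
#       Results:
#         - Metrics:
#             FID50k: 3.42
#             P&R50k: 69.8/40.2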
def check_info_from_metafile(meta_info):
    """Check whether eval information matches the description from metafile.

    TODO: Set a dict containing the tolerance for different configurations.
    """
    method_cat = meta_info['method']
    ckpt_name = meta_info['ckpt_name']
    metafile_path = os.path.join('./configs', method_cat, 'metafile.yml')

    meta_data_orig = read_metafile(metafile_path)['Models']
    results_orig = None

    for info in meta_data_orig:
        if ckpt_name in info['Weights']:
            results_orig = info['Results']
            break

    if results_orig is None:
        print(f'Cannot find related models for {ckpt_name}')
        return False

    metric_value_orig = {}
    results_metric_orig = None
    for info in results_orig:
        if 'Metrics' in info:
            results_metric_orig = info['Metrics']
            break
    assert results_metric_orig is not None, 'Cannot find Metrics in metafile.'
    # get the original metric value
    for k, v in results_metric_orig.items():
        if k in metric_name_mapping:
            metric_value_orig[
                metric_name_mapping[k]] = convert_str_metric_value(v)

    assert len(metric_value_orig
               ) > 0, f'Cannot get metric value in metafile for {ckpt_name}'

    # get the metric value from evaluation table
    eval_info = meta_info['eval_table']
    metric_value_eval = {}
    for i, name in enumerate(eval_info[0]):
        if name in metric_name_mapping:
            metric_value_eval[
                metric_name_mapping[name]] = convert_str_metric_value(
                    eval_info[1][i])
    assert len(metric_value_eval
               ) > 0, f'Cannot get metric value in eval table: {eval_info}'

    # compare eval info and the original info from metafile
    return compare_eval_orig_info(metric_value_orig, metric_value_eval)


def get_log_files(args):
    """Get all of the log files from the given args.logdir.

    This function is used to initialize ``data_dict``.
    """
    log_paths = mmcv.scandir(args.logdir, '.txt', recursive=True)
    log_paths = [os.path.join(args.logdir, x) for x in log_paths]

    # construct data dict
    for log in log_paths:
        splits = log.split('/')
        method = splits[-2]
        log_name = splits[-1]
        ckpt_name = log_name.replace('_eval_log.txt', '.pth')

        data_dict[log] = dict(
            method=method, log_name=log_name, ckpt_name=ckpt_name)

    print(f'Got {len(data_dict)} logs from {args.logdir}')

    return data_dict


def main():
    args = parse_args()
    get_log_files(args)

    # process log files
    for log_path, meta_info in data_dict.items():
        table_info = read_table_from_log(log_path)

        # only deal with valid table_info
        if len(table_info) == 2:
            meta_info['eval_table'] = table_info
            meta_info['regress_status'] = 'Pass' if check_info_from_metafile(
                meta_info) else 'Failed'
        else:
            meta_info['regress_status'] = 'Missing Eval Info'

    mmcv.mkdir_or_exist(args.out)
    with open(os.path.join(args.out, 'test_regression_report.json'), 'w') as f:
        json.dump(data_dict, f)

    print('-------------- Regression Report --------------')
    print(data_dict)


if __name__ == '__main__':
    main()
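
With the logs in place, the report can be generated by running the parser. A
hypothetical invocation (the log directory name is a placeholder):

    python parse_test_log.py ./work_dirs/eval_logs --out ./work_dirs/

This writes ``test_regression_report.json`` to the ``--out`` directory and
prints each checkpoint's ``regress_status`` (``Pass``, ``Failed``, or
``Missing Eval Info``).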
test_benchmark.sh
@@ -0,0 +1,18 @@
PARTITION=$1
CHECKPOINT_DIR=$2

# stylegan2
echo 'configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k.py' &
# GPUS=1 for the test with data parallel
GPUS=1 bash tools/slurm_eval.sh ${PARTITION} test-benchmark \
    configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k.py \
    ${CHECKPOINT_DIR}/stylegan2/stylegan2_c2_ffhq_256_b4x8_20210407_160709-7890ae1f.pth \
    --online &

# stylegan v1
echo 'configs/styleganv1/styleganv1_ffhq_256_g8_25Mimg.py' &
# GPUS=1 for the test with data parallel
GPUS=1 bash tools/slurm_eval.sh ${PARTITION} test-benchmark \
    configs/styleganv1/styleganv1_ffhq_256_g8_25Mimg.py \
    ${CHECKPOINT_DIR}/styleganv1/styleganv1_ffhq_256_g8_25Mimg_20210407_161748-0094da86.pth \
    --online &
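
A hypothetical way to launch the whole benchmark on a Slurm cluster (the
partition name and checkpoint root below are placeholders):

    bash test_benchmark.sh my-partition ./work_dirs/pretrained

Each evaluation is sent to the background with ``&`` so the jobs run
concurrently, and ${CHECKPOINT_DIR} must mirror the ``./configs`` layout
(e.g. ``stylegan2/`` and ``styleganv1/`` subfolders), as noted in the tips of
parse_test_log.py.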
Review comment on parse_test_log.py:
> "matric"? Maybe "Metric"?