-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
134 lines (113 loc) · 6.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import logging.config
import os
import sys
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from core.test import test
from core.train import train
from core.utils import init_logger, make_results_dir
# Command-line entry point: parse arguments, seed the random number generators,
# load the run configuration selected by --case, then either train a model or
# evaluate a saved one. Exits with status 1 if the selected operation fails.
if __name__ == '__main__':
    # ---- Gather command-line arguments ----
    parser = argparse.ArgumentParser(description='TD3 with variable action repeat')
    parser.add_argument('--env', required=True,
                        help='Name of the environment')
    parser.add_argument('--result_dir', default=os.path.join(os.getcwd(), 'results'),
                        help="Directory Path to store results (default: %(default)s)")
    parser.add_argument('--case', required=True, choices=['dm_control', 'mujoco', 'box2d', 'classic_control'],
                        help="It's used for switching between different domains(default: %(default)s)")
    parser.add_argument('--opr', required=True, choices=['train', 'test'])
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='no cuda usage (default: %(default)s)')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Renders the environment (default: %(default)s)')
    parser.add_argument('--force', action='store_true', default=False,
                        help='Overrides past results (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)')
    parser.add_argument('--use_wandb', action='store_true', default=False,
                        help='Use Weight and bias visualization lib (default: %(default)s)')
    parser.add_argument('--action_repeat_mode', choices=['fixed', 'variable'], default='variable',
                        help='Mode of action repeat (default: %(default)s). '
                             'In the fixed mode, an action is simply repeated for fixed number of time-steps'
                             ' whereas in the "variable" mode, we learn the repeat-count')
    parser.add_argument('--fixed_action_repeat', type=int, default=None,
                        help='Action Repeat (default: %(default)s)')
    parser.add_argument('--test_episodes', type=int, default=1,
                        help='Evaluation episode count (default: %(default)s)')
    parser.add_argument('--wandb_dir', default=os.path.join(os.getcwd(), 'wandb'),
                        help="Directory Path to store wandb files (default: %(default)s)")
    parser.add_argument('--restore-model-from-wandb', action='store_true', default=False,
                        help='restore model from wandb run. (default: %(default)s)')
    parser.add_argument('--wandb-run-id', type=str,
                        help='Wandb run id for restoring model (default: %(default)s)')

    # ---- Process arguments ----
    args = parser.parse_args()
    # Use CUDA only when it is both allowed by the user and actually available.
    args.device = 'cuda' if (not args.no_cuda) and torch.cuda.is_available() else 'cpu'

    # Seed the NumPy and PyTorch random number generators.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device == 'cuda':
        torch.cuda.manual_seed(args.seed)

    # Import the run configuration that matches the requested domain. The
    # imports are deferred so that only the selected domain's module is loaded.
    if args.case == 'classic_control':
        from config.classic_control import run_config
    elif args.case == 'box2d':
        from config.box2d import run_config
    elif args.case == 'mujoco':
        from config.mujoco import run_config
    elif args.case == 'dm_control':
        from config.dm_control import run_config
    else:
        # Defensive only: argparse `choices` already rejects other values.
        raise Exception('Invalid --case option.')

    # Apply the command-line arguments to the configuration.
    run_config.set_config(args)
    log_base_path = make_results_dir(run_config.exp_path, args)
    if args.use_wandb:
        # Created once here for both operations; no need to repeat it later.
        os.makedirs(args.wandb_dir, exist_ok=True)

    # Set up logging and record the command line that started this run.
    init_logger(log_base_path)
    logging.getLogger('root').info('cmd args:{}'.format(' '.join(sys.argv[1:])))

    failed = False
    try:
        if args.opr == 'train':
            if args.use_wandb:
                os.environ['WANDB_DIR'] = str(args.wandb_dir)
                import wandb
                wandb.init(job_type='train', dir=args.wandb_dir, group=args.case + ':' + args.env,
                           project="variable-td3", config=run_config.get_hparams(), sync_tensorboard=True)

            summary_writer = SummaryWriter(run_config.exp_path, flush_secs=60 * 1)  # flush every 1 minute
            try:
                train(run_config, summary_writer)
            finally:
                # Always flush and close the writer, even when training fails,
                # so events recorded before the failure are not lost.
                summary_writer.flush()
                summary_writer.close()

            if args.use_wandb:
                # NOTE(review): newer wandb releases expose this as wandb.finish();
                # confirm against the wandb version this project pins.
                wandb.join()

        elif args.opr == 'test':
            model_path = run_config.model_path

            # Optionally fetch the saved model file from a wandb run first.
            if args.restore_model_from_wandb:
                # Explicit check instead of `assert`, which is skipped under `python -O`.
                if args.wandb_run_id is None:
                    raise ValueError('wandb run id cannot be {}'.format(args.wandb_run_id))
                import wandb
                root, name = os.path.split(model_path)
                wandb.restore(name=name, run_path=args.wandb_run_id, replace=True, root=root)

            # Fail early with a clear message when the model file is missing;
            # checking only the truthiness of the path would never catch that.
            if not model_path or not os.path.exists(model_path):
                raise FileNotFoundError('model not found: {}'.format(model_path))

            # Evaluation always runs on the CPU, whatever device was used for training.
            model = run_config.get_uniform_network()
            model = model.to('cpu')
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

            if args.render and args.case == 'mujoco':
                # Ref: https://github.com/openai/mujoco-py/issues/390
                from mujoco_py import GlfwContext
                GlfwContext(offscreen=True)

            env = run_config.new_game()
            try:
                test_score, test_repeat_counts = test(env, model, args.test_episodes,
                                                      device='cpu', render=args.render,
                                                      save_test_data=True, save_path=run_config.test_data_path,
                                                      recording_path=run_config.recording_path)
            finally:
                # Release the environment even if evaluation raises.
                env.close()
            logging.getLogger('test').info('Test Score: {}'.format(test_score))

        else:
            raise ValueError('"--opr {}" is not implemented ( or not valid)'.format(args.opr))
    except Exception as e:
        # Top-level boundary: record the full traceback in the run log and
        # remember the failure so the process can report it via its exit status.
        logging.getLogger('root').error(e, exc_info=True)
        failed = True

    logging.shutdown()
    if failed:
        # Signal the failure to the calling shell / job scheduler.
        sys.exit(1)