-
Notifications
You must be signed in to change notification settings - Fork 162
/
train.py
118 lines (97 loc) · 4.56 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gym
import torch
import json
import os
import yaml
from tqdm import trange
import maml_rl.envs
from maml_rl.metalearners import MAMLTRPO
from maml_rl.baseline import LinearFeatureBaseline
from maml_rl.samplers import MultiTaskSampler
from maml_rl.utils.helpers import get_policy_for_env, get_input_size
from maml_rl.utils.reinforcement_learning import get_returns
def main(args):
    """Run meta-training with the MAMLTRPO meta-learner from a YAML config.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.config`` (path to the YAML configuration file),
            ``args.output_folder`` (directory receiving ``config.json`` and
            ``policy.th``; nothing is written when it is ``None``),
            ``args.seed``, ``args.num_workers`` and ``args.device``.

    Side effects:
        When ``args.output_folder`` is set, the folder is created if needed,
        the configuration merged with the command-line arguments is dumped to
        ``config.json``, and the policy ``state_dict`` is written to
        ``policy.th`` after every batch.
    """
    # NOTE(review): FullLoader can build more than plain YAML scalars/maps.
    # Only point --config at trusted files, or switch to yaml.safe_load if
    # the configs never rely on Python-specific tags.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    if args.output_folder is not None:
        # exist_ok=True replaces the racy "check, then create" sequence and
        # is a no-op when the directory is already there.
        os.makedirs(args.output_folder, exist_ok=True)
        policy_filename = os.path.join(args.output_folder, 'policy.th')
        config_filename = os.path.join(args.output_folder, 'config.json')

        # Record the command-line arguments next to the YAML settings so the
        # saved config.json describes the whole run.
        config.update(vars(args))
        with open(config_filename, 'w') as f:
            json.dump(config, f, indent=2)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    env_kwargs = config.get('env-kwargs', {})
    env = gym.make(config['env-name'], **env_kwargs)
    # The instance is still handed to the policy/baseline/sampler builders
    # below after being closed -- presumably only its spaces are needed there
    # (TODO confirm).
    env.close()

    # Policy
    policy = get_policy_for_env(env,
                                hidden_sizes=config['hidden-sizes'],
                                nonlinearity=config['nonlinearity'])
    policy.share_memory()

    # Baseline
    baseline = LinearFeatureBaseline(get_input_size(env))

    # Sampler
    sampler = MultiTaskSampler(config['env-name'],
                               env_kwargs=env_kwargs,
                               batch_size=config['fast-batch-size'],
                               policy=policy,
                               baseline=baseline,
                               env=env,
                               seed=args.seed,
                               num_workers=args.num_workers)

    metalearner = MAMLTRPO(policy,
                           fast_lr=config['fast-lr'],
                           first_order=config['first-order'],
                           device=args.device)

    # Running total of episode lengths seen so far (train + validation).
    num_iterations = 0
    for _ in trange(config['num-batches']):
        tasks = sampler.sample_tasks(num_tasks=config['meta-batch-size'])
        futures = sampler.sample_async(tasks,
                                       num_steps=config['num-steps'],
                                       fast_lr=config['fast-lr'],
                                       gamma=config['gamma'],
                                       gae_lambda=config['gae-lambda'],
                                       device=args.device)
        # The meta-learner consumes the futures directly; the episodes are
        # only collected afterwards for bookkeeping.
        logs = metalearner.step(*futures,
                                max_kl=config['max-kl'],
                                cg_iters=config['cg-iters'],
                                cg_damping=config['cg-damping'],
                                ls_max_steps=config['ls-max-steps'],
                                ls_backtrack_ratio=config['ls-backtrack-ratio'])

        train_episodes, valid_episodes = sampler.sample_wait(futures)
        # Only the first entry of train_episodes is counted/reported --
        # presumably the first adaptation step (TODO confirm).
        num_iterations += sum(sum(episode.lengths)
                              for episode in train_episodes[0])
        num_iterations += sum(sum(episode.lengths)
                              for episode in valid_episodes)
        # NOTE(review): `logs` is not consumed after this update in the
        # visible code.
        logs.update(tasks=tasks,
                    num_iterations=num_iterations,
                    train_returns=get_returns(train_episodes[0]),
                    valid_returns=get_returns(valid_episodes))

        # Save policy (overwritten after every batch so the file always
        # holds the latest weights).
        if args.output_folder is not None:
            with open(policy_filename, 'wb') as f:
                torch.save(policy.state_dict(), f)
def build_parser():
    """Build the command-line parser of the training script.

    Returns:
        argparse.ArgumentParser: parser with the required ``--config`` option
        and the optional ``--output-folder``, ``--seed``, ``--num-workers``
        and ``--use-cuda`` options.
    """
    # Imported locally: these modules are only needed for the command line.
    import argparse
    import multiprocessing as mp

    # cpu_count() - 1 is 0 on a single-core machine; clamp it so the default
    # always provides at least one sampling worker.
    default_workers = max(mp.cpu_count() - 1, 1)

    parser = argparse.ArgumentParser(description='Reinforcement learning with '
        'Model-Agnostic Meta-Learning (MAML) - Train')

    parser.add_argument('--config', type=str, required=True,
        help='path to the configuration file.')

    # Miscellaneous
    misc = parser.add_argument_group('Miscellaneous')
    misc.add_argument('--output-folder', type=str,
        help='name of the output folder')
    misc.add_argument('--seed', type=int, default=None,
        help='random seed')
    misc.add_argument('--num-workers', type=int, default=default_workers,
        help='number of workers for trajectories sampling (default: '
             '{0})'.format(default_workers))
    misc.add_argument('--use-cuda', action='store_true',
        help='use cuda (default: false, use cpu). WARNING: Full support for cuda '
        'is not guaranteed. Using CPU is encouraged.')

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    # CUDA is used only when it was requested AND is actually available;
    # otherwise the run falls back to the CPU.
    args.device = ('cuda' if (torch.cuda.is_available()
                   and args.use_cuda) else 'cpu')

    main(args)