-
Notifications
You must be signed in to change notification settings - Fork 16
/
train_meta_task.py
382 lines (327 loc) · 16 KB
/
train_meta_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# Adapted from PureJaxRL implementation and minigrid baselines, source:
# https://github.com/lupuandr/explainable-policies/blob/50acbd777dc7c6d6b8b7255cd1249e81715bcb54/purejaxrl/ppo_rnn.py#L4
# https://github.com/lcswillems/rl-starter-files/blob/master/model.py
import os
import shutil
import time
from dataclasses import asdict, dataclass
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax
import orbax
import pyrallis
import wandb
from flax.jax_utils import replicate, unreplicate
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from nn import ActorCriticRNN
from utils import Transition, calculate_gae, ppo_update_networks, rollout
import xminigrid
from xminigrid.benchmarks import Benchmark
from xminigrid.environment import Environment, EnvParams
from xminigrid.wrappers import DirectionObservationWrapper, GymAutoResetWrapper
# Globally enable JAX's "jax_threefry_partitionable" PRNG mode at import time,
# before any keys are created below.
# this will be default in new jax versions anyway
jax.config.update("jax_threefry_partitionable", True)
@dataclass
class TrainConfig:
    """Hyperparameters for meta-task PPO training (parsed from the CLI by pyrallis).

    ``__post_init__`` derives per-device sizes and the number of meta/inner
    updates from the global settings. These derived attributes are plain
    instance attributes, not dataclass fields, so ``asdict`` does not include them.
    """

    project: str = "xminigrid"
    group: str = "default"
    name: str = "meta-task-ppo"
    env_id: str = "XLand-MiniGrid-R1-9x9"
    benchmark_id: str = "trivial-1m"
    img_obs: bool = False
    # agent
    obs_emb_dim: int = 16
    action_emb_dim: int = 16
    rnn_hidden_dim: int = 1024
    rnn_num_layers: int = 1
    head_hidden_dim: int = 256
    # training
    enable_bf16: bool = False
    num_envs: int = 8192
    num_steps_per_env: int = 4096
    num_steps_per_update: int = 32
    update_epochs: int = 1
    num_minibatches: int = 16
    total_timesteps: int = 100_000_000
    lr: float = 0.001
    clip_eps: float = 0.2
    gamma: float = 0.99
    gae_lambda: float = 0.95
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    eval_num_envs: int = 512
    eval_num_episodes: int = 10
    eval_seed: int = 42
    train_seed: int = 42
    checkpoint_path: Optional[str] = None

    def __post_init__(self):
        """Derive per-device sizes and update counts, validating divisibility first."""
        num_devices = jax.local_device_count()
        # splitting computation across all available devices
        assert self.num_envs % num_devices == 0, "num_envs must be divisible by the number of devices"
        self.num_envs_per_device = self.num_envs // num_devices
        self.total_timesteps_per_device = self.total_timesteps // num_devices
        self.eval_num_envs_per_device = self.eval_num_envs // num_devices
        # minibatches are built by reshaping the per-device env axis into
        # (num_minibatches, -1), which requires exact divisibility
        assert (
            self.num_envs_per_device % self.num_minibatches == 0
        ), "num_envs per device must be divisible by num_minibatches"

        self.num_meta_updates = round(
            self.total_timesteps_per_device / (self.num_envs_per_device * self.num_steps_per_env)
        )
        # zero meta updates would make training a no-op and the linear lr schedule
        # divides by this value
        assert self.num_meta_updates > 0, "total_timesteps is too small for a single meta update"

        assert self.num_steps_per_env % self.num_steps_per_update == 0
        self.num_inner_updates = self.num_steps_per_env // self.num_steps_per_update
        print(f"Num devices: {num_devices}, Num meta updates: {self.num_meta_updates}")
def make_states(config: TrainConfig):
    """Build the environment, benchmark, network and initial training state.

    Args:
        config: training configuration; the per-device fields derived in
            ``TrainConfig.__post_init__`` define the dummy batch size used for
            parameter initialization.

    Returns:
        Tuple ``(rng, env, env_params, benchmark, init_hstate, train_state)``,
        where ``rng`` is the PRNG key left over after network initialization.

    Raises:
        ValueError: if ``config.env_id`` is not an XLand meta-task environment.
    """

    # for learning rate scheduling
    def linear_schedule(count):
        # optimizer steps per meta update, assuming one optimizer step per
        # minibatch update (TODO confirm against ppo_update_networks)
        total_inner_updates = config.num_minibatches * config.update_epochs * config.num_inner_updates
        # stepwise linear decay: lr stays constant within a meta update and
        # drops by lr / num_meta_updates after each completed meta update
        frac = 1.0 - (count // total_inner_updates) / config.num_meta_updates
        return config.lr * frac

    # setup environment
    if "XLand" not in config.env_id:
        raise ValueError("Only meta-task environments are supported.")

    env, env_params = xminigrid.make(config.env_id)
    env = GymAutoResetWrapper(env)
    env = DirectionObservationWrapper(env)

    # enabling image observations if needed
    if config.img_obs:
        from xminigrid.experimental.img_obs import RGBImgObservationWrapper

        env = RGBImgObservationWrapper(env)

    # loading benchmark
    benchmark = xminigrid.load_benchmark(config.benchmark_id)

    # set up training state
    rng = jax.random.key(config.train_seed)
    rng, _rng = jax.random.split(rng)

    network = ActorCriticRNN(
        num_actions=env.num_actions(env_params),
        obs_emb_dim=config.obs_emb_dim,
        action_emb_dim=config.action_emb_dim,
        rnn_hidden_dim=config.rnn_hidden_dim,
        rnn_num_layers=config.rnn_num_layers,
        head_hidden_dim=config.head_hidden_dim,
        img_obs=config.img_obs,
        dtype=jnp.bfloat16 if config.enable_bf16 else None,
    )
    # dummy inputs used only to initialize parameters: [batch_size, seq_len, ...]
    shapes = env.observation_shape(env_params)
    init_obs = {
        "obs_img": jnp.zeros((config.num_envs_per_device, 1, *shapes["img"])),
        "obs_dir": jnp.zeros((config.num_envs_per_device, 1, shapes["direction"])),
        "prev_action": jnp.zeros((config.num_envs_per_device, 1), dtype=jnp.int32),
        "prev_reward": jnp.zeros((config.num_envs_per_device, 1)),
    }
    init_hstate = network.initialize_carry(batch_size=config.num_envs_per_device)

    network_params = network.init(_rng, init_obs, init_hstate)
    tx = optax.chain(
        optax.clip_by_global_norm(config.max_grad_norm),
        # inject_hyperparams exposes the scheduled lr in the optimizer state for logging
        optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-8),  # eps=1e-5
    )
    train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)

    return rng, env, env_params, benchmark, init_hstate, train_state
def make_train(
    env: Environment,
    env_params: EnvParams,
    benchmark: Benchmark,
    config: TrainConfig,
):
    """Build the pmapped meta-training function.

    Each meta update:
      * samples a fresh ruleset per environment and resets the environments;
      * runs ``config.num_inner_updates`` PPO updates of
        ``config.num_steps_per_update`` steps each, carrying the RNN hidden
        state across inner updates;
      * evaluates the agent on rulesets derived from ``config.eval_seed``.

    Returns:
        ``train(rng, train_state, init_hstate)``, pmapped over the ``"devices"``
        axis, returning ``{"state": train_state, "loss_info": ...}`` where
        ``loss_info`` is stacked over meta updates.
    """

    @partial(jax.pmap, axis_name="devices")
    def train(
        rng: jax.Array,
        train_state: TrainState,
        init_hstate: jax.Array,
    ):
        # hidden state of the first env with a leading axis of size 1; passed to
        # the vmapped rollout with in_axes=None, i.e. shared by all eval envs
        eval_hstate = init_hstate[0][None]

        # META TRAIN LOOP
        def _meta_step(meta_state, _):
            rng, train_state = meta_state

            # INIT ENV
            # independent keys for ruleset sampling and env resets; the carried
            # rng must not be reused to derive either of them
            rng, _rng1, _rng2 = jax.random.split(rng, num=3)
            ruleset_rng = jax.random.split(_rng1, num=config.num_envs_per_device)
            reset_rng = jax.random.split(_rng2, num=config.num_envs_per_device)

            # sample rulesets for this meta update
            rulesets = jax.vmap(benchmark.sample_ruleset)(ruleset_rng)
            meta_env_params = env_params.replace(ruleset=rulesets)

            timestep = jax.vmap(env.reset, in_axes=(0, 0))(meta_env_params, reset_rng)
            prev_action = jnp.zeros(config.num_envs_per_device, dtype=jnp.int32)
            prev_reward = jnp.zeros(config.num_envs_per_device)

            # INNER TRAIN LOOP
            def _update_step(runner_state, _):
                # COLLECT TRAJECTORIES
                def _env_step(runner_state, _):
                    rng, train_state, prev_timestep, prev_action, prev_reward, prev_hstate = runner_state

                    # SELECT ACTION
                    rng, _rng = jax.random.split(rng)
                    dist, value, hstate = train_state.apply_fn(
                        train_state.params,
                        {
                            # [batch_size, seq_len=1, ...]
                            "obs_img": prev_timestep.observation["img"][:, None],
                            "obs_dir": prev_timestep.observation["direction"][:, None],
                            "prev_action": prev_action[:, None],
                            "prev_reward": prev_reward[:, None],
                        },
                        prev_hstate,
                    )
                    action, log_prob = dist.sample_and_log_prob(seed=_rng)
                    # squeeze seq_len where possible
                    action, value, log_prob = action.squeeze(1), value.squeeze(1), log_prob.squeeze(1)

                    # STEP ENV
                    timestep = jax.vmap(env.step, in_axes=0)(meta_env_params, prev_timestep, action)
                    transition = Transition(
                        # ATTENTION: done is always false, as we optimize for entire meta-rollout
                        done=jnp.zeros_like(timestep.last()),
                        action=action,
                        value=value,
                        reward=timestep.reward,
                        log_prob=log_prob,
                        obs=prev_timestep.observation["img"],
                        dir=prev_timestep.observation["direction"],
                        prev_action=prev_action,
                        prev_reward=prev_reward,
                    )
                    runner_state = (rng, train_state, timestep, action, timestep.reward, hstate)
                    return runner_state, transition

                # hidden state at the start of this chunk, used to replay the
                # sequence during the PPO update
                initial_hstate = runner_state[-1]
                # transitions: [seq_len, batch_size, ...]
                runner_state, transitions = jax.lax.scan(_env_step, runner_state, None, config.num_steps_per_update)

                # CALCULATE ADVANTAGE
                rng, train_state, timestep, prev_action, prev_reward, hstate = runner_state
                # calculate value of the last step for bootstrapping
                _, last_val, _ = train_state.apply_fn(
                    train_state.params,
                    {
                        "obs_img": timestep.observation["img"][:, None],
                        "obs_dir": timestep.observation["direction"][:, None],
                        "prev_action": prev_action[:, None],
                        "prev_reward": prev_reward[:, None],
                    },
                    hstate,
                )
                advantages, targets = calculate_gae(transitions, last_val.squeeze(1), config.gamma, config.gae_lambda)

                # UPDATE NETWORK
                def _update_epoch(update_state, _):
                    def _update_minbatch(train_state, batch_info):
                        init_hstate, transitions, advantages, targets = batch_info
                        new_train_state, update_info = ppo_update_networks(
                            train_state=train_state,
                            transitions=transitions,
                            init_hstate=init_hstate.squeeze(1),
                            advantages=advantages,
                            targets=targets,
                            clip_eps=config.clip_eps,
                            vf_coef=config.vf_coef,
                            ent_coef=config.ent_coef,
                        )
                        return new_train_state, update_info

                    rng, train_state, init_hstate, transitions, advantages, targets = update_state

                    # MINIBATCHES PREPARATION
                    rng, _rng = jax.random.split(rng)
                    permutation = jax.random.permutation(_rng, config.num_envs_per_device)
                    # [seq_len, batch_size, ...]
                    batch = (init_hstate, transitions, advantages, targets)
                    # [batch_size, seq_len, ...], as our model assumes
                    batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch)

                    # shuffle whole env sequences, never individual time steps
                    shuffled_batch = jtu.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
                    # [num_minibatches, minibatch_size, ...]
                    minibatches = jtu.tree_map(
                        lambda x: jnp.reshape(x, (config.num_minibatches, -1) + x.shape[1:]), shuffled_batch
                    )
                    train_state, update_info = jax.lax.scan(_update_minbatch, train_state, minibatches)

                    update_state = (rng, train_state, init_hstate, transitions, advantages, targets)
                    return update_state, update_info

                # add a leading axis of size 1 so the hidden state can be swapped
                # and shuffled together with the [seq_len, batch_size, ...] tensors
                update_state = (rng, train_state, initial_hstate[None, :], transitions, advantages, targets)
                update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config.update_epochs)
                # WARN: do not forget to get updated params
                rng, train_state = update_state[:2]

                # averaging over minibatches then over epochs
                loss_info = jtu.tree_map(lambda x: x.mean(-1).mean(-1), loss_info)

                runner_state = (rng, train_state, timestep, prev_action, prev_reward, hstate)
                return runner_state, loss_info

            # on each meta-update we reset rnn hidden to init_hstate
            runner_state = (rng, train_state, timestep, prev_action, prev_reward, init_hstate)
            runner_state, loss_info = jax.lax.scan(_update_step, runner_state, None, config.num_inner_updates)
            # WARN: do not forget to get updated params
            rng, train_state = runner_state[:2]

            # EVALUATE AGENT
            # NOTE: the eval key comes from the constant config.eval_seed, so every
            # meta update evaluates on the same rulesets and reset keys
            eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))
            eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)
            eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)

            eval_ruleset = jax.vmap(benchmark.sample_ruleset)(eval_ruleset_rng)
            eval_env_params = env_params.replace(ruleset=eval_ruleset)

            eval_stats = jax.vmap(rollout, in_axes=(0, None, 0, None, None, None))(
                eval_reset_rng,
                env,
                eval_env_params,
                train_state,
                eval_hstate,
                config.eval_num_episodes,
            )
            eval_stats = jax.lax.pmean(eval_stats, axis_name="devices")

            # averaging over inner updates, adding evaluation metrics
            loss_info = jtu.tree_map(lambda x: x.mean(-1), loss_info)
            loss_info.update(
                {
                    "eval/returns_mean": eval_stats.reward.mean(0),
                    "eval/returns_median": jnp.median(eval_stats.reward),
                    "eval/lengths": eval_stats.length.mean(0),
                    "eval/lengths_20percentile": jnp.percentile(eval_stats.length, q=20),
                    "eval/returns_20percentile": jnp.percentile(eval_stats.reward, q=20),
                    # opt_state[-1] is the inject_hyperparams(adam) state of the optax chain
                    "lr": train_state.opt_state[-1].hyperparams["learning_rate"],
                }
            )

            meta_state = (rng, train_state)
            return meta_state, loss_info

        meta_state = (rng, train_state)
        meta_state, loss_info = jax.lax.scan(_meta_step, meta_state, None, config.num_meta_updates)
        return {"state": meta_state[-1], "loss_info": loss_info}

    return train
@pyrallis.wrap()
def train(config: TrainConfig):
    """Run meta-task PPO training end to end.

    The run:
      * initializes a wandb run;
      * removes any existing checkpoint directory;
      * compiles and runs the pmapped training function;
      * logs one wandb entry per meta update;
      * optionally saves the config and final params with orbax.

    Args:
        config: training configuration, parsed from the command line by
            ``pyrallis.wrap``.
    """
    # logging to wandb
    run = wandb.init(
        project=config.project,
        group=config.group,
        name=config.name,
        config=asdict(config),
        save_code=True,
    )
    # removing existing checkpoints if any
    if config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
        shutil.rmtree(config.checkpoint_path)

    rng, env, env_params, benchmark, init_hstate, train_state = make_states(config)
    # replicating args across devices
    rng = jax.random.split(rng, num=jax.local_device_count())
    train_state = replicate(train_state, jax.local_devices())
    init_hstate = replicate(init_hstate, jax.local_devices())

    print("Compiling...")
    t = time.time()
    train_fn = make_train(env, env_params, benchmark, config)
    train_fn = train_fn.lower(rng, train_state, init_hstate).compile()
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time:.2f}s.")

    print("Training...")
    t = time.time()
    train_info = jax.block_until_ready(train_fn(rng, train_state, init_hstate))
    elapsed_time = time.time() - t
    print(f"Done in {elapsed_time:.2f}s.")

    print("Logging...")
    # transfer all metrics to host once instead of syncing on every scalar
    loss_info = jax.device_get(unreplicate(train_info["loss_info"]))
    transitions_per_meta_update = config.num_steps_per_env * config.num_envs_per_device * jax.local_device_count()
    total_transitions = 0
    for i in range(config.num_meta_updates):
        total_transitions += transitions_per_meta_update
        info = jtu.tree_map(lambda x: x[i].item(), loss_info)
        info["transitions"] = total_transitions
        wandb.log(info)

    run.summary["training_time"] = elapsed_time
    # throughput over the transitions actually collected: num_meta_updates is
    # rounded, so it can differ from config.total_timesteps
    run.summary["steps_per_second"] = total_transitions / elapsed_time

    if config.checkpoint_path is not None:
        # only the train state is needed here, so unreplicate just that subtree
        checkpoint = {"config": asdict(config), "params": unreplicate(train_info["state"]).params}
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(checkpoint)
        orbax_checkpointer.save(config.checkpoint_path, checkpoint, save_args=save_args)

    print("Final return: ", float(loss_info["eval/returns_mean"][-1]))
    run.finish()
if __name__ == "__main__":
    # pyrallis.wrap builds TrainConfig from the command line, so no arguments here
    train()