-
Notifications
You must be signed in to change notification settings - Fork 7
/
optimizer.py
816 lines (669 loc) · 32 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
from collections import Counter
from collections import deque
from datetime import datetime
import argparse
import copy
import gc
import io
import logging
import os
import pickle
import re
import socket
import time
import random
from google.cloud import storage
from tensorboardX import SummaryWriter
import numpy as np
import pika
from scipy.signal import lfilter
import torch
import torch.distributed as dist
from distributed import DistributedDataParallelSparseParamCPU
from dotaservice.protos.DotaService_pb2 import TEAM_DIRE, TEAM_RADIANT
from policy import Policy
from policy import REWARD_KEYS
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Seed every random number generator used here (torch, stdlib, numpy) for reproducibility.
_SEED = 7
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(_SEED)

# Machine epsilon of float32, as a Python float; used to avoid division by zero.
eps = float(np.finfo(np.float32).eps)

# First CUDA device when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def is_distributed():
    """Return True when torch.distributed is built in AND a process group is initialized."""
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
def is_master():
    """Return True for the rank-0 process, or for the only process when not distributed."""
    if not is_distributed():
        # A single, non-distributed process is its own master.
        return True
    return torch.distributed.get_rank() == 0
def discount(x, gamma):
    """Discounted cumulative sum along axis 0: y[t] = sum_k gamma**k * x[t + k].

    Runs the IIR filter y[t] = x[t] + gamma * y[t + 1] over the time-reversed input,
    then reverses the result back. The output is cast to float32.
    """
    reversed_x = x[::-1]
    filtered = lfilter([1], [1, -gamma], reversed_x, axis=0)
    return filtered[::-1].astype(np.float32)
def advantage_returns(rewards, values, gamma, lam):
    """Compute GAE-Lambda advantages and discounted returns.

    Both `rewards` and `values` carry one more element than there are actions (the
    bootstrap entry); the returned arrays are one element shorter than the inputs.

    Returns:
        (advantages, returns): advantages for the policy loss, and rewards-to-go as
        targets for the value function.
    """
    # One-step TD residuals: delta_t = r_t + gamma * V_{t+1} - V_t.
    td_errors = rewards[:-1] + gamma * values[1:] - values[:-1]
    returns = discount(rewards, gamma)[:-1]
    advantages = discount(td_errors, gamma * lam)
    return advantages, returns
class MessageQueue:
    """Blocking RabbitMQ client for the experience queue and the (optional) model exchange."""

    EXPERIENCE_QUEUE_NAME = 'experience'
    MODEL_EXCHANGE_NAME = 'model'
    MAX_RETRIES = 10
    RETRY_DELAY_S = 5  # Seconds to sleep between connection attempts.

    def __init__(self, host, port, prefetch_count, use_model_exchange):
        """
        Args:
            host (str): RabbitMQ host.
            port (int): RabbitMQ port.
            prefetch_count (int): Amount of messages to prefetch. Setting this variable too
                high can result in blocked pipes that time out.
            use_model_exchange (bool): Declare the model exchange on `connect`; must be True
                for `publish_model` to work.
        """
        self._params = pika.ConnectionParameters(
            host=host,
            port=port,
            heartbeat=300,
        )
        self.prefetch_count = prefetch_count
        self.use_model_exchange = use_model_exchange
        self._conn = None
        self._xp_channel = None
        self._model_exchange = None

    def process_events(self):
        """Alias of `process_data_events`, kept for backward compatibility."""
        self.process_data_events()

    def _open_connection(self):
        """Open a BlockingConnection, retrying up to MAX_RETRIES times.

        Raises:
            pika.exceptions.AMQPConnectionError: the error of the last failed attempt.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                self._conn = pika.BlockingConnection(self._params)
            except pika.exceptions.AMQPConnectionError:
                # Also catches ConnectionClosed, which is a subclass of AMQPConnectionError.
                if attempt == self.MAX_RETRIES:
                    logger.error('Connection to RMQ failed after {} attempts.'.format(self.MAX_RETRIES))
                    raise
                logger.error('Connection to RMQ failed. retrying. ({}/{})'.format(attempt, self.MAX_RETRIES))
                time.sleep(self.RETRY_DELAY_S)
            else:
                logger.info('Connected to RMQ')
                return

    def connect(self):
        """Ensure the connection and its channels are open, (re)creating whatever is not.

        Safe to call repeatedly: open connections/channels are left untouched.

        Raises:
            pika.exceptions.AMQPConnectionError: if no connection could be established.
        """
        reconnected = False
        if not self._conn or self._conn.is_closed:
            self._open_connection()
            reconnected = True

        # Channels of a previous connection are unusable after a reconnect. Callers also
        # invoke this after a ChannelClosed, when the connection itself may still be open,
        # so closed channels are reopened independently of the connection.
        if reconnected or self._xp_channel is None or self._xp_channel.is_closed:
            self._xp_channel = self._conn.channel()
            self._xp_channel.basic_qos(prefetch_count=self.prefetch_count)
            self._xp_channel.queue_declare(queue=self.EXPERIENCE_QUEUE_NAME)

        if self.use_model_exchange and (
                reconnected or self._model_exchange is None or self._model_exchange.is_closed):
            self._model_exchange = self._conn.channel()
            self._model_exchange.exchange_declare(
                exchange=self.MODEL_EXCHANGE_NAME,
                exchange_type='x-recent-history',
                arguments={'x-recent-history-length': 1},
            )

    @property
    def xp_queue_size(self):
        """Number of messages waiting in the experience queue, or None if it can't be queried."""
        try:
            res = self._xp_channel.queue_declare(queue=self.EXPERIENCE_QUEUE_NAME, passive=True)
            return res.method.message_count
        except Exception:  # Best effort: this is only used for monitoring.
            return None

    def process_data_events(self):
        """Give pika a chance to process I/O (e.g. send heartbeats); never raises."""
        # Sends heartbeat, might keep conn healthier.
        try:
            self._conn.process_data_events()
        except Exception as e:  # Best effort: a failed heartbeat must not kill the caller.
            logger.debug('process_data_events failed: {!r}'.format(e))

    def _publish_model(self, msg, hdr):
        self._model_exchange.basic_publish(
            exchange=self.MODEL_EXCHANGE_NAME,
            routing_key='',
            body=msg,
            properties=pika.BasicProperties(headers=hdr),
        )

    def publish_model(self, *args, **kwargs):
        """Publish `msg` with headers `hdr` to the model exchange, reconnecting once on failure."""
        try:
            self._publish_model(*args, **kwargs)
        except (pika.exceptions.ConnectionClosed, pika.exceptions.ChannelClosed):
            logger.error('reconnecting to queue')
            self.connect()
            self._publish_model(*args, **kwargs)

    def _consume_xp(self):
        method, properties, body = next(self._xp_channel.consume(
            queue=self.EXPERIENCE_QUEUE_NAME,
            no_ack=False,
        ))
        # Acknowledged right away, before the caller processes the message.
        self._xp_channel.basic_ack(delivery_tag=method.delivery_tag)
        return method, properties, body

    def consume_xp(self):
        """Fetch and ack the next experience message; returns (method, properties, body).

        Reconnects once on a closed connection/channel.
        """
        try:
            return self._consume_xp()
        except (pika.exceptions.ConnectionClosed, pika.exceptions.ChannelClosed):
            logger.error('reconnecting to queue')
            self.connect()
            return self._consume_xp()

    def close(self):
        """Close the connection if it is open."""
        if self._conn and self._conn.is_open:
            logger.info('closing queue connection')
            self._conn.close()
class Sequence:
    """One `seq_len`-long slice of a rollout, holding everything PPO needs for that slice.

    `advantages` and `returns` start as None; they are assigned once the whole
    rollout has been processed.
    """

    def __init__(self, game_id, weight_version, team_id, observations, actions, masks, values, rewards, hidden, log_probs_sel):
        # Origin of the data.
        self.game_id = game_id
        self.team_id = team_id
        self.weight_version = weight_version
        # Model inputs, head masks and the actions that were taken.
        self.observations = observations
        self.masks = masks
        self.actions = actions
        # Value estimates, rewards, recurrent state fed into this slice and the
        # log-probabilities of the selected actions.
        self.values = values
        self.rewards = rewards
        self.hidden = hidden
        self.log_probs_sel = log_probs_sel
        # Assigned later.
        self.advantages = None
        self.returns = None
def all_gather(t):
    """Gather tensor `t` from every rank and concatenate the results along dim 0, in rank order."""
    world_size = dist.get_world_size()
    buffers = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(buffers, t)
    gathered = torch.cat(buffers)
    return gathered
class DotaOptimizer:
    """PPO trainer for the Dota policy, fed with rollouts received over RabbitMQ.

    Every iteration of `run` consumes whole rollouts until at least `min_seq_per_epoch`
    sequences are gathered, performs `epochs` PPO updates on them and, when
    `checkpoint` is set, writes TensorBoard metrics and publishes the new weights.
    """

    MODEL_FILENAME_FMT = "model_%09d.pt"
    BUCKET_NAME = 'dotaservice'
    MODEL_HISTOGRAM_FREQ = 128
    MAX_GRAD_NORM = 0.5
    SPEED_KEY = 'steps per s'
    # Digits immediately preceding the ".pt" extension, e.g. "model_000000042.pt" -> 42.
    _MODEL_ITERATION_RE = re.compile(r'(\d+)(?=\.pt)')

    def __init__(self, rmq_host, rmq_port, epochs, min_seq_per_epoch, seq_len,
                 learning_rate, checkpoint, pretrained_model, mq_prefetch_count, log_dir,
                 entropy_coef, vf_coef, run_local):
        """
        Args:
            rmq_host (str): RabbitMQ host.
            rmq_port (int): RabbitMQ port.
            epochs (int): PPO epochs per batch of gathered experience.
            min_seq_per_epoch (int): Minimum number of sequences to gather per iteration;
                whole rollouts are consumed, so the actual number can be higher.
            seq_len (int): Length every rollout slice is cut/padded to.
            learning_rate (float): Adam learning rate.
            checkpoint (bool): Whether this process writes TensorBoard events and models.
            pretrained_model (str or None): Model to start from. Overridden by the latest
                model found under `log_dir`, if any.
            mq_prefetch_count (int): RabbitMQ prefetch count.
            log_dir (str): Local directory (and GCS prefix when not `run_local`) for
                events and models.
            entropy_coef (float): Entropy bonus coefficient.
            vf_coef (float): Value-function loss coefficient.
            run_local (bool): If True, Google Cloud Storage is never used.
        """
        super().__init__()
        self.rmq_host = rmq_host
        self.rmq_port = rmq_port
        self.epochs = epochs
        self.min_seq_per_epoch = min_seq_per_epoch
        self.seq_len = seq_len
        self.learning_rate = learning_rate
        self.checkpoint = checkpoint
        self.mq_prefetch_count = mq_prefetch_count
        self.iteration_start = 1
        self.policy_base = Policy()
        self.log_dir = log_dir
        self.entropy_coef = entropy_coef
        self.vf_coef = vf_coef
        self.run_local = run_local
        self.iterations = 100000
        self.model_upload_freq = 10
        self.eventfile_refresh_freq = 100
        self.e_clip = 0.1

        if self.checkpoint:
            self.writer = None
            logger.info('Checkpointing to: {}'.format(self.log_dir))
            os.makedirs(self.log_dir, exist_ok=True)

        # if we are not running locally, set bucket
        if not self.run_local:
            client = storage.Client()
            self.bucket = client.get_bucket(self.BUCKET_NAME)

        # If log_dir already holds a model, resume from the latest one.
        latest_model = self.get_latest_model(prefix=self.log_dir)
        if latest_model is not None:
            logger.info('Found a latest model in pretrained dir: {}'.format(latest_model))
            if pretrained_model is not None:
                logger.warning('Overriding pretrained model by latest model.')
            pretrained_model = latest_model

        if pretrained_model is not None:
            self.iteration_start = self.iteration_from_model_filename(filename=pretrained_model) + 1
            # if we are not running locally, pull down model
            if not self.run_local:
                logger.info('Downloading: {}'.format(pretrained_model))
                model_blob = self.bucket.get_blob(pretrained_model)
                if model_blob is None:
                    raise FileNotFoundError('Model not found in bucket {}: {}'.format(
                        self.BUCKET_NAME, pretrained_model))
                # TODO(tzaman): Download to BytesIO and supply to torch in that way.
                pretrained_model = '/tmp/model.pt'
                model_blob.download_to_filename(pretrained_model)
            self.policy_base.load_state_dict(torch.load(pretrained_model, map_location=device),
                                             strict=False)

        if is_distributed():
            self.policy = DistributedDataParallelSparseParamCPU(self.policy_base)
        else:
            self.policy = self.policy_base

        self.policy.to(device)

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)

        self.time_last_it = time.time()

        self.mq = MessageQueue(host=self.rmq_host, port=self.rmq_port,
                               prefetch_count=mq_prefetch_count,
                               use_model_exchange=self.checkpoint)
        self.mq.connect()

        # Publish the starting weights before any step is taken (also when resuming), so
        # agents have a model to play with right away.
        self.upload_model(version=self.iteration_start)

    @staticmethod
    def iteration_from_model_filename(filename):
        """Return the iteration number encoded in a model filename like ``model_000000042.pt``.

        Raises:
            ValueError: if the file name has no digits right before ``.pt``.
        """
        match = DotaOptimizer._MODEL_ITERATION_RE.search(os.path.basename(filename))
        if match is None:
            raise ValueError('Cannot parse an iteration number from model filename {!r}'.format(filename))
        return int(match.group(1))

    def get_latest_model(self, prefix):
        """Return the newest ``.pt`` model under `prefix`, or None if there is none.

        On GCS, lists the bucket blobs whose name starts with `prefix` and returns a blob
        name. Locally, lists the regular files in directory `prefix` (None if it does not
        exist) and returns ``os.path.join(prefix, filename)``. "Newest" is the
        lexicographically largest name, which follows iteration order thanks to the
        zero-padded MODEL_FILENAME_FMT.
        """
        if not self.run_local:
            names = [blob.name for blob in self.bucket.list_blobs(prefix=prefix)]
        else:
            if not os.path.isdir(prefix):
                return None
            names = [os.path.join(prefix, fn) for fn in os.listdir(prefix)
                     if os.path.isfile(os.path.join(prefix, fn))]
        models = sorted(name for name in names if name.endswith('.pt'))
        return models[-1] if models else None

    @property
    def events_filename(self):
        # NOTE: reaches into tensorboardX private attributes; may break on upgrades.
        return self.writer.file_writer.event_writer._ev_writer._file_name

    def get_rollout(self):
        """Consume one rollout message from the experience queue.

        Returns:
            tuple: (rollout data dict, reward sums per reward column, rollout length,
            weight version the rollout was played with, canvas image).
        """
        # TODO(tzaman): make a rollout object
        _, _, body = self.mq.consume_xp()
        # NOTE(review): pickle.loads can execute arbitrary code embedded in the payload;
        # this is only safe as long as only trusted producers can reach the queue.
        data = pickle.loads(body)
        rollout_len = data['rewards'].shape[0]
        # Compute rewards per topic, reduce-sum down the sequences.
        subrewards = data['rewards'].sum(axis=0)
        return data, subrewards, rollout_len, data['weight_version'], data['canvas']

    def experiences_from_rollout(self, data):
        """Slice one rollout into `seq_len`-step Sequences and attach advantages/returns.

        The last slice is zero-padded up to `seq_len`, and the recurrent hidden state is
        carried from one slice into the next. Advantages and returns are computed over
        the whole rollout (treated as terminated) and then split back per sequence.
        """
        # TODO(tzaman): The rollout can consist out of multiple viable sequences.
        # These should be padded and then sliced into separate experiences.
        observations = data['observations']
        masks = data['masks']
        actions = data['actions']
        rewards = data['rewards']
        rollout_len = rewards.shape[0]
        logger.debug('rollout_len={}'.format(rollout_len))

        sequences = []
        hidden = self.policy.init_hidden().to(device)
        values = []
        rewards_sum = []
        for i1 in range(0, rollout_len, self.seq_len):
            # Only the final slice can be shorter than seq_len; it gets padded.
            pad = max(0, self.seq_len - (rollout_len - i1))
            i2 = i1 + self.seq_len - pad
            logger.debug('Slice[{}:{}], pad={}'.format(i1, i2, pad))

            # Slice out the relevant parts.
            s_observations = {key: val[i1:i2, :].to(device) for key, val in observations.items()}
            s_masks = {key: val[i1:i2, :].to(device) for key, val in masks.items()}
            s_actions = {key: val[i1:i2, :].to(device) for key, val in actions.items()}
            s_rewards = rewards[i1:i2]

            if pad:
                # F.pad lists (before, after) pairs starting from the LAST dimension, so
                # these pad only dim 0 (time), at the end, for rank 1, 2 and 3 tensors.
                dim_pad = {
                    1: (0, pad),
                    2: (0, 0, 0, pad),
                    3: (0, 0, 0, 0, 0, pad),
                }
                for d in (s_observations, s_masks, s_actions):
                    for key, val in d.items():
                        d[key] = torch.nn.functional.pad(
                            val, pad=dim_pad[val.dim()], mode='constant', value=0).detach()
                s_rewards = np.pad(s_rewards, ((0, pad), (0, 0)), mode='constant')

            input_hidden = hidden
            head_logits_dict, s_values, hidden = self.policy.sequence(**s_observations, hidden=input_hidden)

            log_probs_sel = {}
            for key in head_logits_dict:
                log_probs = Policy.masked_softmax(logits=head_logits_dict[key], mask=s_masks[key].unsqueeze(0))
                log_probs_sel[key] = torch.masked_select(input=log_probs, mask=s_actions[key]).detach()

            # The values and rewards are gathered here over all sequences, because the values are
            # cumulative, and therefore need the first step from each next sequence. To optimize this,
            # we gather them here, and process these after the loop, and add them to the Sequence
            # object later.
            values.append(s_values)
            rewards_sum.append(np.sum(s_rewards, axis=1).ravel())

            sequences.append(Sequence(
                game_id=data['game_id'],
                weight_version=data['weight_version'],
                team_id=data['team_id'],
                observations=s_observations,
                actions=s_actions,
                masks=s_masks,
                values=s_values.detach(),
                rewards=s_rewards,
                hidden=input_hidden.detach(),
                log_probs_sel=log_probs_sel,
            ))

        # TODO(tzaman): For now, we assume we are always presented with full and terminated sequences.
        # This is why we append a zero here, so the advantage and return computation works efficiently.
        # In the future, if we support receiving non-terminated sequences, we need one more state and
        # reward than we have action.
        values = torch.cat(values).cpu().numpy().ravel()
        values = np.append(values, np.array(0., dtype=np.float32))

        rewards_sum = np.concatenate(rewards_sum)
        rewards_sum = np.append(rewards_sum, np.array(0., dtype=np.float32))

        advantages, returns = advantage_returns(rewards=rewards_sum, values=values, gamma=0.98, lam=0.97)

        # Split the advantages and returns back up into their respective sequences.
        advantages = np.split(advantages, len(sequences))
        returns = np.split(returns, len(sequences))

        for s, a, r in zip(sequences, advantages, returns):
            s.advantages = torch.from_numpy(a)
            s.returns = torch.from_numpy(r)

        return sequences

    @staticmethod
    def list_of_dicts_to_dict_of_lists(x):
        """Turn a list of dicts of tensors into a dict of stacked tensors (one per key)."""
        return {k: torch.stack([d[k] for d in x]) for k in x[0]}

    def run(self):
        """Main loop: per iteration gather experience, optimize, log and publish weights."""
        for it in range(self.iteration_start, self.iterations):
            logger.info('iteration {}/{}'.format(it, self.iterations))

            # First grab a bunch of experiences
            experiences = []
            subrewards = []
            rollout_lens = []
            weight_ages = []
            canvas = None  # Just save only the last canvas.
            start_xp = time.time()
            xp_waits = 0
            while len(experiences) < self.min_seq_per_epoch:
                logger.debug(' adding experience @{}/{}'.format(len(experiences), self.min_seq_per_epoch))

                # Get new experiences from a new rollout.
                with torch.no_grad():
                    start_xp_wait = time.time()
                    rollout, rollout_subrewards, rollout_len, weight_version, canvas = self.get_rollout()
                    xp_waits += time.time() - start_xp_wait
                    rollout_experiences = self.experiences_from_rollout(data=rollout)

                experiences.extend(rollout_experiences)
                subrewards.append(rollout_subrewards)
                rollout_lens.append(rollout_len)
                # Number of iterations between the weights used to play and the current ones.
                weight_ages.append(it - weight_version)
            time_xp = time.time() - start_xp

            losses = []
            entropies = []
            grad_norms = []
            start_optimizing = time.time()
            for ep in range(self.epochs):
                logger.info(' epoch {}/{}'.format(ep + 1, self.epochs))
                # Let pika send heartbeats during the (possibly long) optimization.
                self.mq.process_data_events()
                loss_d, entropy_d, grad_norm_d = self.train(experiences=experiences)
                losses.append(loss_d)
                entropies.append(entropy_d)
                grad_norms.append(grad_norm_d)
            time_optimizing = time.time() - start_optimizing

            losses = self.list_of_dicts_to_dict_of_lists(losses)
            loss = losses['loss'].mean()

            entropies = self.list_of_dicts_to_dict_of_lists(entropies)
            entropy = torch.stack(list(entropies.values())).sum(dim=0).mean()

            grad_norms = self.list_of_dicts_to_dict_of_lists(grad_norms)

            n_steps = len(experiences) * self.seq_len

            subrewards_per_sec = np.stack(subrewards) / n_steps * Policy.OBSERVATIONS_PER_SECOND
            rollout_rewards = subrewards_per_sec.sum(axis=1)
            reward_dict = dict(zip(REWARD_KEYS, subrewards_per_sec.sum(axis=0)))
            reward_per_sec = rollout_rewards.sum()

            rollout_lens = torch.tensor(rollout_lens, dtype=torch.float32)
            avg_rollout_len = rollout_lens.mean()

            weight_ages = torch.tensor(weight_ages, dtype=torch.float32)
            avg_weight_age = weight_ages.mean()

            time_it = time.time() - self.time_last_it
            steps_per_s = n_steps / time_it
            self.time_last_it = time.time()

            metrics = {
                self.SPEED_KEY: steps_per_s,
                'reward_per_sec/sum': reward_per_sec,
                'loss/sum': loss,
                'loss/policy': losses['policy_loss'].mean(),
                'loss/entropy': losses['entropy_loss'].mean(),
                'loss/value': losses['value_loss'].mean(),
                'entropy': entropy,
                'avg_rollout_len': avg_rollout_len,
                'avg_weight_age': avg_weight_age,
            }

            metrics['timing/it'] = time_it  # Full total time since last step
            metrics['timing/xp_total'] = time_xp
            metrics['timing/xp_mq_wait'] = xp_waits
            metrics['timing/optimizer'] = time_optimizing

            for k, v in entropies.items():
                metrics['entropy/{}'.format(k)] = v.mean()

            for k, v in grad_norms.items():
                metrics['grad_norm/{}'.format(k)] = v.mean()

            for k, v in reward_dict.items():
                metrics['reward_per_sec/{}'.format(k)] = v

            logger.info('steps_per_s={:.2f}, avg_weight_age={:.1f}, reward_per_sec={:.4f}, loss={:.4f}, entropy={:.3f}'.format(
                steps_per_s, float(avg_weight_age), reward_per_sec, float(loss), float(entropy)))

            # Only measured when checkpointing; reported as 0 on other processes.
            time_checkpoint = 0.
            time_model_upload = 0.
            if self.checkpoint:
                start_checkpoint = time.time()
                # TODO(tzaman): re-introduce distributed metrics. See commits from december 2017.

                if self.writer is None or it % self.eventfile_refresh_freq == 0:
                    logger.info('(Re-)Creating TensorBoard eventsfile (#iteration={})'.format(it))
                    self.writer = SummaryWriter(log_dir=self.log_dir)

                # Write metrics to events file.
                for name, metric in metrics.items():
                    self.writer.add_scalar(name, metric, it)

                # TODO(tzaman): How to add the time spent on writing events file and model?

                # Add per-iteration histograms
                self.writer.add_histogram('losses', losses['loss'], it)
                self.writer.add_histogram('rollout_lens', rollout_lens, it)
                self.writer.add_histogram('weight_age', weight_ages, it)

                # Rewards histogram
                self.writer.add_histogram('rewards_per_sec_per_rollout', rollout_rewards, it)

                # Model
                if it % self.MODEL_HISTOGRAM_FREQ == 1:
                    for name, param in self.policy_base.named_parameters():
                        self.writer.add_histogram('param/' + name, param.clone().cpu().data.numpy(), it)

                self.writer.add_image('canvas', canvas, it, dataformats='HWC')

                # RMQ Queue size.
                queue_size = self.mq.xp_queue_size
                if queue_size is not None:
                    self.writer.add_scalar('mq_size', queue_size, it)

                # Upload events to GCS
                self.writer.file_writer.flush()  # Flush before uploading
                if not self.run_local:
                    blob = self.bucket.blob(self.events_filename)
                    blob.upload_from_filename(filename=self.events_filename)
                time_checkpoint = time.time() - start_checkpoint

                start_model_upload = time.time()
                self.upload_model(version=it)
                time_model_upload = time.time() - start_model_upload

            logger.info('Timings: it={:.2f}, xp={:.2f}, xp_wait={:.2f} opt={:.2f}, upload_tb={:.2f}, upload_model={:.2f}'.format(
                time_it, time_xp, xp_waits, time_optimizing, time_checkpoint, time_model_upload))

    @staticmethod
    def _stack_field(experiences, field, keys):
        """Stack the per-key tensors of dict attribute `field` of all experiences, per key."""
        stacked = {key: [] for key in keys}
        for e in experiences:
            for key, val in getattr(e, field).items():
                stacked[key].append(val)
        return {key: torch.stack(vals) for key, vals in stacked.items()}

    def train(self, experiences):
        """Run one PPO update over all `experiences` as a single batch.

        Args:
            experiences (list[Sequence]): (padded) sequences with `advantages` and
                `returns` assigned.

        Returns:
            tuple: (losses dict with 'loss', 'policy_loss', 'entropy_loss', 'value_loss';
            dict of entropy per output head; dict of 'unclipped'/'clipped' mean gradient
            norms). All returned tensors are detached from the autograd graph.

        Raises:
            ValueError: if the loss or the gradient norm is NaN.
        """
        logger.debug('train(experiences=#{})'.format(len(experiences)))

        # Stack together all experiences; advantages are normalized over the whole batch.
        advantage = torch.stack([e.advantages for e in experiences]).to(device)
        advantage = (advantage - advantage.mean()) / (advantage.std() + eps)
        advantage = advantage.detach()
        returns = torch.stack([e.returns for e in experiences]).detach().to(device)

        hidden = torch.cat([e.hidden for e in experiences], dim=1).detach()

        # The action mask contains the mask of the selected actions
        actions = self._stack_field(experiences, 'actions', Policy.OUTPUT_KEYS)

        # The head mask contains the mask of the relevant heads, where a selection has taken place,
        # and includes only valid possible selections from those heads.
        masks = self._stack_field(experiences, 'masks', Policy.OUTPUT_KEYS)

        observations = self._stack_field(experiences, 'observations', Policy.INPUT_KEYS)

        # Notice there is no notion of loss masking here, this is unnecessary as we only
        # use selected probabilities. E.g. when things were padded, nothing was selected, so no data.
        head_logits_dict, values, _ = self.policy(**observations, hidden=hidden)

        # Perform a masked softmax
        policy_loss = {}
        entropies = {}
        for key in head_logits_dict:
            # actions_step contains if we took an action of this key during the respective step
            actions_step = actions[key].sum(dim=-1) != 0

            if actions_step.sum() == 0:
                # Created on `device` so they can be stacked with the other heads' tensors.
                policy_loss[key] = torch.zeros([], device=device)
                entropies[key] = torch.zeros([], device=device)
                continue

            log_probs = Policy.masked_softmax(logits=head_logits_dict[key], mask=masks[key])

            log_probs_sel = torch.masked_select(input=log_probs, mask=actions[key])
            old_log_probs_sel = torch.cat([e.log_probs_sel[key] for e in experiences]).detach()

            advantage_sel = advantage[actions_step]

            # PPO clipped surrogate objective.
            ratio = torch.exp(log_probs_sel - old_log_probs_sel)
            surr1 = ratio * advantage_sel.view(-1)
            surr2 = torch.clamp(ratio, 1.0 - self.e_clip, 1.0 + self.e_clip) * advantage_sel
            policy_loss[key] = -torch.min(surr1, surr2).mean()

            n_actions = actions_step.sum()  # Amount of actions being taken for this key this batch.
            log_probs_masked = torch.masked_select(input=log_probs, mask=masks[key])
            probs_sel_masked = torch.exp(log_probs_masked)
            entropies[key] = -(probs_sel_masked * log_probs_masked).sum() / n_actions

        # Grab all the policy losses
        policy_loss = torch.stack(list(policy_loss.values())).mean()

        if self.entropy_coef > 0:
            entropy = torch.stack(list(entropies.values())).sum()
            entropy_loss = -self.entropy_coef * entropy
        else:
            entropy_loss = torch.zeros([], device=device)

        if self.vf_coef > 0:
            # Notice we don't have to remove zero-padded entries, as they give 0 loss.
            value_loss = 0.5 * (returns - values.squeeze(-1)).pow(2).mean()
            value_loss = self.vf_coef * value_loss
        else:
            value_loss = torch.zeros([], device=device)

        loss = policy_loss + entropy_loss + value_loss

        if torch.isnan(loss):
            raise ValueError('loss={}, policy_loss={}, entropy_loss={}, value_loss={}'.format(
                loss, policy_loss, entropy_loss, value_loss))

        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = self.mean_gradient_norm()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.MAX_GRAD_NORM)
        grad_norm_clipped = self.mean_gradient_norm()
        if torch.isnan(grad_norm):
            raise ValueError('grad_norm={}'.format(grad_norm))
        self.optimizer.step()

        # Detach so the collected metrics don't keep autograd graphs alive across epochs.
        losses = {
            'loss': loss.detach(),
            'policy_loss': policy_loss.detach(),
            'entropy_loss': entropy_loss.detach(),
            'value_loss': value_loss.detach(),
        }
        entropies = {key: val.detach() for key, val in entropies.items()}
        return losses, entropies, {'unclipped': grad_norm, 'clipped': grad_norm_clipped}

    def mean_gradient_norm(self):
        """Mean of the per-parameter L2 gradient norms (parameters without grad are skipped)."""
        gs = [p.grad.data.norm(2) for p in self.policy.parameters() if p.grad is not None]
        return torch.stack(gs).mean()

    def upload_model(self, version):
        """Save `policy_base` weights as `version` and distribute them.

        The serialized state dict is written under `log_dir`, published on the model
        exchange and, every `model_upload_freq` versions when not `run_local`, uploaded
        to GCS. Does nothing on non-master processes.
        """
        if not is_master():
            # Only rank 0 uploads the model.
            return

        filename = self.MODEL_FILENAME_FMT % version
        rel_path = os.path.join(self.log_dir, filename)

        # Serialize the model.
        buffer = io.BytesIO()
        torch.save(obj=self.policy_base.state_dict(), f=buffer)
        state_dict_b = buffer.getvalue()

        # Write model to file.
        with open(rel_path, 'wb') as f:
            f.write(state_dict_b)

        # Send to exchange.
        self.mq.publish_model(msg=state_dict_b, hdr={'version': version})

        # Upload to GCP.
        if version % self.model_upload_freq == 0 and not self.run_local:
            logger.info('Uploading model {} to GCP'.format(version))
            blob = self.bucket.blob(rel_path)
            blob.upload_from_string(data=state_dict_b)  # Model
def init_distribution(backend='gloo'):
    """Initialize the torch.distributed process group when running with 2+ processes.

    WORLD_SIZE is read from the environment; when it is unset the process is treated
    as a single worker (world size 1) and initialization is skipped. The call to
    `init_process_group` passes no init_method, so (per the PyTorch docs) it falls back
    to ``env://`` and reads its remaining settings from environment variables.

    Args:
        backend (str): torch.distributed backend name.
    """
    logger.info('init_distribution')
    # Use .get() instead of an assert: asserts are stripped under `python -O`, and a
    # missing WORLD_SIZE simply means a single, non-distributed process.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if world_size < 2:
        logger.warning('skipping distribution: world size too small ({})'.format(world_size))
        return
    torch.distributed.init_process_group(backend=backend)
    logger.info("Distribution initialized.")
def main(rmq_host, rmq_port, epochs, min_seq_per_epoch, seq_len, learning_rate,
         pretrained_model, mq_prefetch_count, log_dir, entropy_coef, vf_coef, run_local):
    """Set up (optional) distributed training and run the optimizer loop.

    Only the master process checkpoints (writes events/models and uses the model
    exchange). The RabbitMQ connection is closed when the loop ends for any reason.
    """
    logger.info('main(rmq_host={}, rmq_port={}, epochs={}, min_seq_per_epoch={},'
                ' seq_len={}, learning_rate={}, pretrained_model={}, mq_prefetch_count={},'
                ' log_dir={}, entropy_coef={}, vf_coef={}, run_local={})'.format(
                    rmq_host, rmq_port, epochs, min_seq_per_epoch, seq_len, learning_rate,
                    pretrained_model, mq_prefetch_count, log_dir, entropy_coef, vf_coef, run_local))

    # If applicable, initialize distributed training.
    if torch.distributed.is_available():
        init_distribution()
    else:
        logger.info('distribution unavailable')

    # Only the master should checkpoint.
    checkpoint = is_master()

    dota_optimizer = DotaOptimizer(
        rmq_host=rmq_host,
        rmq_port=rmq_port,
        epochs=epochs,
        min_seq_per_epoch=min_seq_per_epoch,
        seq_len=seq_len,
        learning_rate=learning_rate,
        checkpoint=checkpoint,
        pretrained_model=pretrained_model,
        mq_prefetch_count=mq_prefetch_count,
        log_dir=log_dir,
        entropy_coef=entropy_coef,
        vf_coef=vf_coef,
        run_local=run_local,
    )

    try:
        dota_optimizer.run()
    finally:
        # Release the RabbitMQ connection also on errors and KeyboardInterrupt.
        dota_optimizer.mq.close()
def default_log_dir():
    """Build a run directory name from the current local time and this machine's hostname."""
    timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
    hostname = socket.gethostname()
    return '_'.join([timestamp, hostname])
if __name__ == '__main__':
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--log-dir", type=str, help="log and job dir name", default=default_log_dir())
parser.add_argument("--ip", type=str, help="mq ip", default='127.0.0.1')
parser.add_argument("--port", type=int, help="mq port", default=5672)
parser.add_argument("--epochs", type=int, help="amount of epochs", default=4)
parser.add_argument("--min-seq-per-epoch", type=int, help="minimum amount of sequences per epoch."
"This can be slightly more because we want to process full rollouts.", default=1024)
parser.add_argument("--seq-len", type=int, help="sequence length (as one sample in a minibatch)."
"This is also the length that will be (truncated) backpropped into.", default=16)
parser.add_argument("--learning-rate", type=float, help="learning rate", default=5e-5)
parser.add_argument("--entropy-coef", type=float, help="entropy coef (as proportional addition to the loss)", default=5e-4)
parser.add_argument("--vf-coef", type=float, help="value fn coef (as proportional addition to the loss)", default=0.5)
parser.add_argument("--pretrained-model", type=str, help="pretrained model file within gcs bucket", default=None)
parser.add_argument("--mq-prefetch-count", type=int,
help="amount of experience messages to prefetch from mq", default=1)
parser.add_argument("-l", "--log", dest="log_level", help="Set the logging level",
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO')
parser.add_argument("--run-local", type=bool, help="set to true to run locally (not using GCP)", default=False)
args = parser.parse_args()
logger.setLevel(args.log_level)
try:
main(
rmq_host=args.ip,
rmq_port=args.port,
epochs=args.epochs,
min_seq_per_epoch=args.min_seq_per_epoch,
seq_len=args.seq_len,
learning_rate=args.learning_rate,
pretrained_model=args.pretrained_model,
mq_prefetch_count=args.mq_prefetch_count,
log_dir=args.log_dir,
entropy_coef=args.entropy_coef,
vf_coef=args.vf_coef,
run_local=args.run_local,
)
except KeyboardInterrupt:
pass