worker_train_distributed.py
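"""Worker-machine launcher for distributed TD-leaf chess training.

Spawns one trainer process per worker-side task in a TensorFlow cluster of
parameter servers, trainers, and a "tester" job. The chief and tester tasks
are presumably started by companion scripts on their own machines; only this
machine's trainers are launched here.
"""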
import argparse
from multiprocessing import Process

import tensorflow as tf

from agents.td_leaf_agent import TDLeafAgent
from envs.chess import ChessEnv
from value_model import ValueModel


def work(env, job_name, task_index, cluster, log_dir, verbose):
    server = tf.train.Server(cluster,
                             job_name=job_name,
                             task_index=task_index)

    if job_name == "ps":
        # Parameter servers just serve variables; join() never returns.
        server.join()
    else:
        # Shared variables are placed on the parameter servers; this task's
        # ops stay on its own worker device.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:" + job_name + "/task:%d" % task_index,
                cluster=cluster)):

            # Each trainer keeps a private copy of the value network pinned
            # to its own CPU under the 'local' scope.
            with tf.device("/job:worker/task:%d/cpu:0" % task_index):
                with tf.variable_scope('local'):
                    local_network = ValueModel(is_local=True)

            # The shared (global) network that all trainers update.
            network = ValueModel()

            worker_name = 'worker_%03d' % task_index
            agent = TDLeafAgent(worker_name,
                                network,
                                local_network,
                                env,
                                verbose=verbose)

            summary_op = tf.summary.merge_all()
            scaffold = tf.train.Scaffold(summary_op=summary_op)

        # is_chief=False: checkpoint and variable initialization belong to
        # the chief trainer, which this script never launches.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=False,
                                               checkpoint_dir=log_dir,
                                               save_summaries_steps=1,
                                               scaffold=scaffold) as sess:
            agent.sess = sess
            while not sess.should_stop():
                episode_number = sess.run(agent.increment_train_episode_count)
                reward = agent.train(num_moves=10, depth=3, pre_train=False)
                if agent.verbose:
                    print(worker_name,
                          "EPISODE:", episode_number,
                          "UPDATE:", sess.run(agent.update_count),
                          "REWARD:", reward)
                    print('-' * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("run_name")
    parser.add_argument("chief_ip")
    parser.add_argument("worker_ip")
    parser.add_argument("tester_ip")
    args = parser.parse_args()

    # Five parameter servers (ports 2222-2226) live on the chief machine;
    # each machine hosts 35 trainer/tester tasks on ports 3333-3367.
    ps_hosts = [args.chief_ip + ':' + str(2222 + i) for i in range(5)]
    chief_trainer_hosts = [args.chief_ip + ':' + str(3333 + i) for i in range(35)]
    worker_trainer_hosts = [args.worker_ip + ':' + str(3333 + i) for i in range(35)]
    tester_hosts = [args.tester_ip + ':' + str(3333 + i) for i in range(35)]

    ckpt_dir = "./log/" + args.run_name

    cluster_spec = tf.train.ClusterSpec(
        {"ps": ps_hosts,
         "worker": chief_trainer_hosts + worker_trainer_hosts,
         "tester": tester_hosts})

    # Only the worker machine's trainers are launched here; their task
    # indices are offset past the chief machine's trainers within the
    # "worker" job.
    processes = []
    for task_idx, _ in enumerate(worker_trainer_hosts):
        env = ChessEnv()
        p = Process(target=work,
                    args=(env, 'worker', task_idx + len(chief_trainer_hosts),
                          cluster_spec, ckpt_dir, 1))
        processes.append(p)
        p.start()

    for process in processes:
        process.join()
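
# Example invocation (the IPs below are hypothetical placeholders; the chief
# and tester machines are assumed to run their own launch scripts against the
# same cluster layout):
#
#   python worker_train_distributed.py test_run 10.0.0.1 10.0.0.2 10.0.0.3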