-
Notifications
You must be signed in to change notification settings - Fork 823
/
train.py
141 lines (119 loc) · 4.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gym
import numpy as np
import parl
import argparse
from parl.utils import logger, ReplayMemory
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from parl.env import CompatWrapper, is_gym_version_ge
from parl.algorithms import DQN
# ---- DQN / replay-memory hyper-parameters ----
LEARN_FREQ = 5  # training frequency: learn on every LEARN_FREQ-th step of an episode
MEMORY_SIZE = 200000  # passed to ReplayMemory as its size
MEMORY_WARMUP_SIZE = 200  # replay memory must hold more than this many transitions before learning
BATCH_SIZE = 64  # number of transitions sampled for each learning update
LEARNING_RATE = 5e-4  # passed to DQN as `lr`
GAMMA = 0.99  # passed to DQN as `gamma` (discount factor)
# train an episode
def run_train_episode(agent, env, rpm, learn_freq=None, warmup_size=None,
                      batch_size=None):
    """Run one training episode, storing transitions and learning periodically.

    Every transition is appended to ``rpm``. Once ``rpm`` holds more than
    ``warmup_size`` transitions, the agent learns from a sampled batch on
    every ``learn_freq``-th step of this episode (the step counter restarts
    at 0 for each episode).

    Args:
        agent: object with ``sample(obs)`` and
            ``learn(obs, action, reward, next_obs, done)`` methods.
        env: environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(next_obs, reward, done, info)``.
        rpm: replay memory supporting ``append(obs, action, reward,
            next_obs, done)``, ``len()`` and ``sample_batch(batch_size)``.
        learn_freq: learn every this many steps; defaults to ``LEARN_FREQ``.
        warmup_size: learning starts only when ``len(rpm)`` exceeds this;
            defaults to ``MEMORY_WARMUP_SIZE``.
        batch_size: size passed to ``rpm.sample_batch``; defaults to
            ``BATCH_SIZE``.

    Returns:
        The undiscounted sum of rewards collected during the episode.
    """
    # Resolve defaults at call time so the module-level constants remain the
    # single source of truth.
    if learn_freq is None:
        learn_freq = LEARN_FREQ
    if warmup_size is None:
        warmup_size = MEMORY_WARMUP_SIZE
    if batch_size is None:
        batch_size = BATCH_SIZE

    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)
        next_obs, reward, done, _ = env.step(action)
        rpm.append(obs, action, reward, next_obs, done)
        # train model: only after warm-up, and only every `learn_freq` steps
        if len(rpm) > warmup_size and step % learn_freq == 0:
            # s, a, r, s', done
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample_batch(batch_size)
            agent.learn(batch_obs, batch_action, batch_reward,
                        batch_next_obs, batch_done)
        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward
# evaluate `eval_episodes` (default 5) episodes
def run_evaluate_episodes(agent, eval_episodes=5, render=False,
                          env_name='CartPole-v1'):
    """Evaluate ``agent`` on a fresh environment and return the mean reward.

    Actions come from ``agent.predict`` (not ``agent.sample``). The
    environment created here is closed before returning, even if an
    episode raises, so repeated evaluations do not leak environments or
    render windows.

    Args:
        agent: object with a ``predict(obs)`` method.
        eval_episodes: number of episodes to run. With 0, ``np.mean`` of an
            empty list is returned (``nan``, with a NumPy warning).
        render: whether to render the episodes.
        env_name: gym environment id to evaluate on.

    Returns:
        Mean undiscounted episode reward over ``eval_episodes`` episodes.
    """
    # Compatible for different versions of gym: `render_mode` is only passed
    # at construction time when gym version >= 0.26.0.
    if is_gym_version_ge("0.26.0") and render:
        env = gym.make(env_name, render_mode="human")
    else:
        env = gym.make(env_name)
    env = CompatWrapper(env)
    eval_reward = []
    try:
        for _ in range(eval_episodes):
            obs = env.reset()
            episode_reward = 0
            while True:
                action = agent.predict(obs)
                obs, reward, done, _info = env.step(action)
                episode_reward += reward
                if render:
                    env.render()
                if done:
                    break
            eval_reward.append(episode_reward)
    finally:
        env.close()
    return np.mean(eval_reward)
def main(max_episode=None, eval_interval=50):
    """Train a DQN agent on CartPole, then save a checkpoint and inference model.

    The replay memory is first filled past ``MEMORY_WARMUP_SIZE`` with
    training episodes. Training then runs in chunks of ``eval_interval``
    episodes, each followed by an evaluation. The total number of training
    episodes never exceeds ``max_episode``.

    Args:
        max_episode: number of training episodes after warm-up. Defaults to
            ``args.max_episode``, which only exists when the script is run
            directly (see the ``__main__`` guard). Pass it explicitly when
            calling ``main`` from other code.
        eval_interval: number of training episodes between evaluations.

    Raises:
        ValueError: if ``eval_interval`` is smaller than 1. Otherwise the
            training loop could never advance.
    """
    if max_episode is None:
        max_episode = args.max_episode
    if eval_interval < 1:
        raise ValueError(
            'eval_interval must be >= 1, got {}'.format(eval_interval))

    env = gym.make('CartPole-v0')
    # Compatible for different versions of gym
    env = CompatWrapper(env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # set action_shape = 0 while in discrete control environment
    rpm = ReplayMemory(MEMORY_SIZE, obs_dim, 0)

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
    agent = CartpoleAgent(
        alg, act_dim=act_dim, e_greed=0.1, e_greed_decrement=1e-6)

    # warmup memory
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train_episode(agent, env, rpm)

    # start training
    episode = 0
    while episode < max_episode:
        # train part: clamp the last chunk so we stop exactly at max_episode
        for _ in range(min(eval_interval, max_episode - episode)):
            run_train_episode(agent, env, rpm)
            episode += 1
        # test part
        eval_reward = run_evaluate_episodes(agent, render=False)
        logger.info('episode:{} e_greed:{} Test reward:{}'.format(
            episode, agent.e_greed, eval_reward))

    # The training env is no longer needed; release it before saving.
    env.close()

    # save the parameters to ./model.ckpt
    save_path = './model.ckpt'
    agent.save(save_path)

    # save the model and parameters of policy network for inference
    save_inference_path = './inference_model'
    input_shapes = [[None, obs_dim]]
    input_dtypes = ['float32']
    agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)
if __name__ == '__main__':
    # Parse command-line options into the module-level `args`, which main()
    # reads for its episode budget.
    cli = argparse.ArgumentParser()
    cli.add_argument('--max_episode', type=int, default=800,
                     help='stop condition: number of max episode')
    args = cli.parse_args()
    main()