"""
Implementation of Asynchronous Advantage Actor-Critic (A3C) with Generalized Advantage Estimation (GAE)
This module is designed to train a model for standard Gym environments.
Copyright: Pavel B. Chernov, [email protected]
Date: Dec 2020 - July 2021
License: MIT
"""
from init import *
from optim.models import Models
from optim.dqn import DQN
from optim.scheduler import Scheduler
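# NOTE (assumption): the wildcard import from `init` appears to supply the helpers used
# below (log, get_env, space_shape, WORK_DIR, os, timedelta and torch as `th`); their
# exact definitions live in init.py and are not shown in this file.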


def run(env_name: str,
        model_name: str,
        hidden_size=256,
        total_timesteps=200000,
        **kwargs):
    # Prepare environment
    env = get_env(env_name)
    log.info(f'Env: {env}')

    # Initialize model
    input_shape = space_shape(env.observation_space)
    if kwargs.get('dueling_mode', False):
        # For dueling mode: adjust output_shape so that the model outputs a tuple of (advantages, value)
        output_shape = (space_shape(env.action_space), 1)
    else:
        # For vanilla DQN the model outputs Q-values directly
        output_shape = space_shape(env.action_space)
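    # Assumption: in the dueling architecture the two heads are recombined as
    # Q(s, a) = V(s) + A(s, a) - mean(A(s, ·)); the exact combination is expected to
    # happen inside optim.dqn / optim.models, which are not shown in this file.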
    model_class = Models[model_name]
    model = model_class(input_shape=input_shape, output_shape=output_shape, hidden_size=hidden_size, norm=False)
    log.info(f'Model: {model}')

    # Initialize trainer
    autosave_prefix = f'dqn_{model_name.lower()}_{env_name.rsplit("-", 1)[0].lower()}'
    trainer = DQN(
        env=env,
        model=model,
        step_delay=0.015,
        autosave_dir=WORK_DIR,
        autosave_prefix=autosave_prefix,
        autosave_interval=timedelta(minutes=1),
        log=log,
        **kwargs
    )
    log.info(f'Trainer: {trainer}')

    # Train and save the final weights
    trainer.fit(total_timesteps=total_timesteps, render=True)
    trainer.autosave(force=True)


def test(env_name: str, model_name: str, hidden_size=256, **kwargs):
    # Prepare environment
    env = get_env(env_name)
    log.info(f'Env: {env}')

    # Initialize model with the same architecture as in run()
    input_shape = space_shape(env.observation_space)
    if kwargs.get('dueling_mode', False):
        # For dueling mode: adjust output_shape so that the model outputs a tuple of (advantages, value)
        output_shape = (space_shape(env.action_space), 1)
    else:
        # For vanilla DQN the model outputs Q-values directly
        output_shape = space_shape(env.action_space)
    model_class = Models[model_name]
    model = model_class(input_shape=input_shape, output_shape=output_shape, hidden_size=hidden_size, norm=False)
    log.info(f'Model: {model}')

    # Load the weights autosaved by the trainer during run()
    autosave_prefix = f'dqn_{model_name.lower()}_{env_name.rsplit("-", 1)[0].lower()}'
    model_file_name = os.path.join(WORK_DIR, autosave_prefix + '.model.pt')
    model.load_state_dict(th.load(model_file_name), strict=True)

    # Run the greedy policy indefinitely, rendering each step
    state = env.reset()
    while True:
        # The model is assumed to accept a raw observation and return Q-values
        with th.no_grad():
            action = model(state)
        action = action.argmax(dim=-1).squeeze().item()
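        # Classic Gym step API: returns (observation, reward, done, info)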
        state, reward, done, info = env.step(action)
        env.render()
        if done:
            state = env.reset()


if __name__ == '__main__':
    # run(env_name='CartPole-v0', model_name='fc2')
    # run(env_name='Acrobot-v1', model_name='fc2', entropy_factor=1.0)
    # run(env_name='BreakoutDeterministic-v4', model_name='convfc1', hidden_size=128)
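    # The keyword arguments below are forwarded to the DQN trainer; judging by their names
    # (exact semantics live in optim/dqn.py, not shown here) they appear to select Double DQN
    # (double_mode), optionally a dueling head (dueling_mode), a learning-rate schedule,
    # exploration annealing over `anneal_period` steps, prioritized experience replay with a
    # 50k-transition buffer, a 10k-sample warm-up before learning starts, 4 environment samples
    # and 4 gradient updates per iteration, and a periodic target-network refresh.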
    run(
        env_name='LunarLander-v2',
        model_name='fc2',
        double_mode=True,
        dueling_mode=False,
        learning_rate=Scheduler((5e-4, 1e-7), 's'),
        gamma=0.99,
        batch_size=128,
        total_timesteps=150000,
        anneal_period=int(100000 * 0.12),
        prioritized_replay=True,
        replay_buffer_capacity=50000,
        samples_to_start=10000,
        samples_per_iteration=4,
        updates_per_iteration=4,
        target_update_period=250
    )
    test(env_name='LunarLander-v2', model_name='fc2')