# stable_dqn.py
import gym
import socnavgym
import torch
from socnavgym.wrappers import DiscreteActions
from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from agents.models import Transformer
import argparse
from comet_ml import Experiment
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.results_plotter import ts2xy, plot_results
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor


class TransformerExtractor(BaseFeaturesExtractor):
    """Feature extractor that encodes SocNavGym dict observations with a Transformer."""

    def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
        # TODO: we do not know features_dim before going over all the items,
        # so put a placeholder value here. This is dirty!
        super().__init__(observation_space, features_dim=1)
        self.transformer = Transformer(8, 14, 512, 512, None)
        # Update the features dim manually
        self._features_dim = 512
        print("Using transformer for feature extraction")

    def preprocess_observation(self, obs):
        """
        Convert the dict observation into a single flat tensor by concatenating
        the entity observations along the feature dimension.
        """
        assert isinstance(obs, dict)
        observation = torch.tensor([], device=obs["goal"].device).float()
        if "goal" in obs:
            observation = torch.cat((observation, obs["goal"]), dim=1)
        if "humans" in obs:
            observation = torch.cat((observation, obs["humans"]), dim=1)
        if "laptops" in obs:
            observation = torch.cat((observation, obs["laptops"]), dim=1)
        if "tables" in obs:
            observation = torch.cat((observation, obs["tables"]), dim=1)
        if "plants" in obs:
            observation = torch.cat((observation, obs["plants"]), dim=1)
        if "walls" in obs:
            observation = torch.cat((observation, obs["walls"]), dim=1)
        return observation

    def postprocess_observation(self, obs):
        """
        Split the flat observation into the two inputs expected by the transformer:
        the robot state (first 8 values) and the per-entity states (14 values each).
        """
        if len(obs.shape) == 1:
            obs = obs.reshape(1, -1)
        robot_state = obs[:, 0:8].reshape(obs.shape[0], -1, 8)
        entity_state = obs[:, 8:].reshape(obs.shape[0], -1, 14)
        return robot_state, entity_state

    def forward(self, observations):
        pre = self.preprocess_observation(observations)
        r, e = self.postprocess_observation(pre)
        out = self.transformer(r, e)
        out = out.squeeze(1)
        return out


class CometMLCallback(CheckpointCallback):
    """
    A checkpoint callback (derived from ``CheckpointCallback``) that also logs
    rollout and training metrics to comet_ml at the end of every rollout.

    :param run_name: name of the comet_ml run
    :param save_path: directory where model checkpoints are saved
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, run_name: str, save_path: str, verbose=0):
        super().__init__(save_freq=25000, save_path=save_path, verbose=verbose)
        # Useful attributes inherited from BaseCallback: self.model (the RL algorithm),
        # self.training_env, self.n_calls, self.num_timesteps, self.locals, self.globals,
        # self.logger and self.parent.
        print("Logging using comet_ml")
        self.run_name = run_name
        self.experiment = Experiment(
            api_key="8U8V63x4zSaEk4vDrtwppe8Vg",
            project_name="socnav",
            parse_args=False,
        )
        self.experiment.set_name(self.run_name)

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        metrics = {
            "rollout/ep_rew_mean": safe_mean([ep_info["r"] for ep_info in self.model.ep_info_buffer]),
            "rollout/ep_len_mean": safe_mean([ep_info["l"] for ep_info in self.model.ep_info_buffer]),
        }
        if len(self.model.ep_success_buffer) > 0:
            metrics["rollout/success_rate"] = safe_mean(self.model.ep_success_buffer)
        for key in ("train/loss", "train/n_updates"):
            if key in self.logger.name_to_value:
                metrics[key] = self.logger.name_to_value[key]
        self.experiment.log_metrics(metrics, step=self.num_timesteps)


ap = argparse.ArgumentParser()
ap.add_argument("-e", "--env_config", help="path to environment config", required=True)
ap.add_argument("-r", "--run_name", help="name of comet_ml run", required=True)
ap.add_argument("-s", "--save_path", help="path to save the model", required=True)
ap.add_argument("-u", "--use_transformer", help="True or False, based on whether you want a transformer based feature extractor", required=True, default=False)
ap.add_argument("-d", "--use_deep_net", help="True or False, based on whether you want a deeper Q-network", required=False, default=False)
ap.add_argument("-g", "--gpu", help="gpu id to use", required=False, default="0")
args = vars(ap.parse_args())

# Command-line values arrive as strings, so interpret "True"/"False" explicitly
# instead of relying on the truthiness of a non-empty string.
use_transformer = str(args["use_transformer"]).lower() in ("true", "1", "yes")
use_deep_net = str(args["use_deep_net"]).lower() in ("true", "1", "yes")

env = gym.make("SocNavGym-v1", config=args["env_config"])
env = DiscreteActions(env)

if not use_deep_net:
    net_arch = [512, 256, 128, 64]
else:
    net_arch = [512, 256, 256, 256, 128, 128, 64]

policy_kwargs = {"net_arch": net_arch}
if use_transformer:
    policy_kwargs["features_extractor_class"] = TransformerExtractor

device = "cuda:" + str(args["gpu"]) if torch.cuda.is_available() else "cpu"
model = DQN("MultiInputPolicy", env, verbose=0, policy_kwargs=policy_kwargs, device=device)

callback = CometMLCallback(args["run_name"], args["save_path"])
model.learn(total_timesteps=50000 * 200, callback=callback)
model.save(args["save_path"])
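
# Example invocation (a sketch only; the config path, run name, and save directory
# below are placeholders, not files shipped with this repository). The -u and -d
# flags expect the literal strings "True" or "False", as interpreted above:
#
#   python stable_dqn.py -e ./configs/env.yaml -r dqn_transformer_run \
#       -s ./models/dqn_transformer -u True -d False -g 0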