Add DDQN (#598)
* Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test (#590)

* Runnable. Should set up a benchmark and test performance.

* Refine logic

* DQN test on Gym passed

* Refine explore strategy

* Minor

* Minor

* Add Dueling DQN in CIM scenario

* Resolve PR comments

* Add one more explanation

* Fix env_sampler eval info list issue

* Update version to 0.3.2a4

---------

Co-authored-by: Huoran Li <[email protected]>
Co-authored-by: Jinyu Wang <[email protected]>
3 people authored Oct 27, 2023
1 parent b3c6a58 commit 2977097
Showing 25 changed files with 541 additions and 334 deletions.
72 changes: 52 additions & 20 deletions examples/cim/rl/algorithms/dqn.py
@@ -1,10 +1,11 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from typing import Optional, Tuple
 
 import torch
 from torch.optim import RMSprop
 
-from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -23,32 +24,62 @@


 class MyQNet(DiscreteQNet):
-    def __init__(self, state_dim: int, action_num: int) -> None:
+    def __init__(
+        self,
+        state_dim: int,
+        action_num: int,
+        dueling_param: Optional[Tuple[dict, dict]] = None,
+    ) -> None:
         super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
-        self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
-        self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
+
+        self._use_dueling = dueling_param is not None
+        self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
+        if self._use_dueling:
+            q_kwargs, v_kwargs = dueling_param
+            self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
+            self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)
+
+        self._optim = RMSprop(self.parameters(), lr=learning_rate)
 
     def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
-        return self._fc(states)
+        logits = self._fc(states)
+        if self._use_dueling:
+            q = self._q(logits)
+            v = self._v(logits)
+            logits = q - q.mean(dim=1, keepdim=True) + v
+        return logits
 
 
 def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
+    q_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": torch.nn.LeakyReLU,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+    v_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": None,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+
     return ValueBasedPolicy(
         name=name,
-        q_net=MyQNet(state_dim, action_num),
-        exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(2, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 5,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        q_net=MyQNet(
+            state_dim,
+            action_num,
+            dueling_param=(q_kwargs, v_kwargs),
+        ),
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
         warmup=100,
     )
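
A minimal, self-contained sketch of the dueling aggregation that MyQNet applies above, i.e. Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). The layer sizes and activations here are illustrative stand-ins, not the CIM defaults from q_net_conf.

import torch
import torch.nn as nn


class DuelingHead(nn.Module):
    """Illustrative dueling head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

    def __init__(self, state_dim: int, action_num: int, hidden_dim: int = 128) -> None:
        super().__init__()
        # Shared trunk producing a per-state feature vector.
        self.trunk = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.LeakyReLU())
        # Advantage stream: one output per action (the "_q" branch in the diff).
        self.adv = nn.Linear(hidden_dim, action_num)
        # Value stream: a single scalar per state (the "_v" branch in the diff).
        self.val = nn.Linear(hidden_dim, 1)

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        feat = self.trunk(states)
        adv = self.adv(feat)
        val = self.val(feat)
        # Center the advantages before adding the state value, as in the diff.
        return adv - adv.mean(dim=1, keepdim=True) + val


# Shape check with arbitrary dimensions.
q_values = DuelingHead(state_dim=16, action_num=4)(torch.randn(8, 16))
print(q_values.shape)  # torch.Size([8, 4])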

@@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
             num_epochs=10,
             soft_update_coef=0.1,
             double=False,
-            random_overwrite=False,
+            alpha=1.0,
+            beta=1.0,
         ),
     )
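
The alpha and beta fields added to DQNParams above line up with the "prioritized sampling support" mentioned in the commit message: in standard prioritized experience replay, alpha controls how strongly absolute TD errors skew the sampling distribution and beta sets the strength of the importance-sampling correction. A rough, framework-agnostic sketch of those two roles (not MARO's internal replay implementation):

import numpy as np


def per_sample(td_errors: np.ndarray, batch_size: int, alpha: float = 1.0, beta: float = 1.0, eps: float = 1e-6):
    """Sample indices proportionally to |TD error| ** alpha; return importance-sampling weights with exponent beta."""
    priorities = (np.abs(td_errors) + eps) ** alpha
    probs = priorities / priorities.sum()
    indices = np.random.choice(len(td_errors), size=batch_size, p=probs)
    # Importance-sampling weights correct the bias introduced by non-uniform sampling.
    weights = (len(td_errors) * probs[indices]) ** (-beta)
    weights /= weights.max()  # normalize for numerical stability
    return indices, weights


errs = np.random.randn(1000)
idx, w = per_sample(errs, batch_size=32, alpha=1.0, beta=1.0)
print(idx.shape, float(w.min()), float(w.max()))
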
2 changes: 1 addition & 1 deletion examples/cim/rl/config.py
@@ -35,4 +35,4 @@

 action_num = len(action_shaping_conf["action_space"])
 
-algorithm = "ppo"  # ac, ppo, dqn or discrete_maddpg
+algorithm = "dqn"  # ac, ppo, dqn or discrete_maddpg
16 changes: 2 additions & 14 deletions examples/vm_scheduling/rl/algorithms/dqn.py
@@ -6,7 +6,7 @@
 from torch.optim import SGD
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
 
-from maro.rl.exploration import MultiLinearExplorationScheduler
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
     return ValueBasedPolicy(
         name=name,
         q_net=MyQNet(state_dim, action_num, num_features),
-        exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(100, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 400,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
         warmup=100,
     )
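
Both example scripts now pass explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num) instead of an exploration function plus scheduler. MARO's EpsilonGreedy internals are not shown in this diff, so the snippet below is only a standalone sketch of the selection rule it names: with probability epsilon, replace the greedy action with a uniformly random one.

import numpy as np


class EpsilonGreedySketch:
    """Illustrative epsilon-greedy selection; not MARO's actual class."""

    def __init__(self, epsilon: float, num_actions: int) -> None:
        self.epsilon = epsilon
        self.num_actions = num_actions

    def get_action(self, greedy_action: np.ndarray) -> np.ndarray:
        # For each sample, decide whether to explore, then mix random and greedy actions.
        explore = np.random.rand(len(greedy_action)) < self.epsilon
        random_action = np.random.randint(self.num_actions, size=len(greedy_action))
        return np.where(explore, random_action, greedy_action)


strategy = EpsilonGreedySketch(epsilon=0.4, num_actions=5)
print(strategy.get_action(np.array([3, 3, 3, 3])))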

2 changes: 1 addition & 1 deletion maro/__misc__.py
@@ -2,6 +2,6 @@
 # Licensed under the MIT license.
 
 
-__version__ = "0.3.2a3"
+__version__ = "0.3.2a4"
 
 __data_version__ = "0.2"
12 changes: 4 additions & 8 deletions maro/rl/exploration/__init__.py
@@ -1,14 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
-from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
+from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration
 
 __all__ = [
-    "AbsExplorationScheduler",
-    "LinearExplorationScheduler",
-    "MultiLinearExplorationScheduler",
-    "epsilon_greedy",
-    "gaussian_noise",
-    "uniform_noise",
+    "ExploreStrategy",
+    "EpsilonGreedy",
+    "LinearExploration",
 ]
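
With scheduling.py deleted below, epsilon schedules presumably move into the strategy objects themselves (hence the LinearExploration export). Its actual signature is not part of this diff, so the class below is a purely hypothetical illustration of how a linear epsilon decay can live inside a strategy rather than in a separate scheduler.

import numpy as np


class LinearDecayEpsilonGreedy:
    """Hypothetical strategy: epsilon decays linearly with the number of calls."""

    def __init__(self, num_actions: int, start_epsilon: float = 0.4, end_epsilon: float = 0.0, decay_steps: int = 10000) -> None:
        self.num_actions = num_actions
        self.start_epsilon = start_epsilon
        self.end_epsilon = end_epsilon
        self.decay_steps = decay_steps
        self._step = 0

    @property
    def epsilon(self) -> float:
        # Linear interpolation from start_epsilon to end_epsilon over decay_steps calls.
        frac = min(self._step / self.decay_steps, 1.0)
        return self.start_epsilon + frac * (self.end_epsilon - self.start_epsilon)

    def get_action(self, greedy_action: int) -> int:
        eps = self.epsilon
        self._step += 1
        if np.random.rand() < eps:
            return int(np.random.randint(self.num_actions))
        return greedy_action


strategy = LinearDecayEpsilonGreedy(num_actions=5, decay_steps=100)
print([strategy.get_action(0) for _ in range(10)])
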
127 changes: 0 additions & 127 deletions maro/rl/exploration/scheduling.py

This file was deleted.
