Skip to content

Commit

Permalink
update runner smac
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhangliu committed Nov 4, 2023
1 parent 2f292e1 commit 0ba2161
Show file tree
Hide file tree
Showing 17 changed files with 175 additions and 156 deletions.
4 changes: 2 additions & 2 deletions benchmark_marl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
def parse_args():
parser = argparse.ArgumentParser("Run an MARL demo.")
parser.add_argument("--method", type=str, default="mappo")
parser.add_argument("--env", type=str, default="mpe")
parser.add_argument("--env-id", type=str, default="simple_spread_v3")
parser.add_argument("--env", type=str, default="sc2")
parser.add_argument("--env-id", type=str, default="3m")
parser.add_argument("--seed", type=int, default=10)
parser.add_argument("--test", type=int, default=0)
parser.add_argument("--device", type=str, default="cuda:0")
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/1c3s5z.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/25m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/2m_vs_1z.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace merged observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/2s3z.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/3m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/5m_vs_6m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/8m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/8m_vs_9m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/MMM2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
2 changes: 1 addition & 1 deletion xuance/configs/mappo/sc2/corridor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_global_state: True # if use global state to calculate values
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
1 change: 0 additions & 1 deletion xuance/configs/vdac/sc2/2m_vs_1z.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace merged observations
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
3 changes: 1 addition & 2 deletions xuance/configs/vdac/sc2/3m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
mixer: "Independent" # choices: Independent, VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

Expand All @@ -45,7 +45,6 @@ gamma: 0.99 # discount factor
# tricks
use_linear_lr_decay: False # if use linear learning rate decay
end_factor_lr_decay: 0.5
use_global_state: False # if use global state to replace joint observations
use_grad_norm: True # gradient normalization
max_grad_norm: 10.0
use_value_clip: True # limit the value range
Expand Down
9 changes: 7 additions & 2 deletions xuance/torch/agents/multi_agent_rl/mappo_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def __init__(self,
"rnn": config.rnn} if self.use_recurrent else {}
representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create representation for critic
input_representation[0] = (config.dim_state,) if self.use_global_state else (config.dim_obs * config.n_agents,)
if self.use_global_state:
input_representation[0] = (config.dim_state + config.dim_obs * config.n_agents,)
else:
input_representation[0] = (config.dim_obs * config.n_agents,)
representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create policy
input_policy = get_policy_in_marl(config, (representation, representation_critic))
Expand Down Expand Up @@ -83,7 +86,9 @@ def values(self, obs_n, *rnn_hidden, state=None):
# build critic input
if self.use_global_state:
state = torch.Tensor(state).unsqueeze(1).to(self.device)
critic_in = state.expand(-1, self.n_agents, -1)
obs_n = torch.Tensor(obs_n).view([batch_size, 1, -1]).to(self.device)
critic_in = torch.concat([obs_n.expand(-1, self.n_agents, -1),
state.expand(-1, self.n_agents, -1)], dim=-1)
else:
critic_in = torch.Tensor(obs_n).view([batch_size, 1, -1]).to(self.device)
critic_in = critic_in.expand(-1, self.n_agents, -1)
Expand Down
1 change: 0 additions & 1 deletion xuance/torch/agents/multi_agent_rl/vdac_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(self,

input_representation = get_repre_in(config)
self.use_recurrent = config.use_recurrent
self.use_global_state = config.use_global_state
# create representation for actor
kwargs_rnn = {"N_recurrent_layers": config.N_recurrent_layers,
"dropout": config.dropout,
Expand Down
6 changes: 5 additions & 1 deletion xuance/torch/learners/multi_agent_rl/mappo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ def update_recurrent(self, sample):
# critic loss
rnn_hidden_critic = self.policy.representation_critic.init_hidden(batch_size * self.n_agents)
if self.use_global_state:
_, value_pred = self.policy.get_values(state[:, :, :-1], IDs[:, :, :-1], *rnn_hidden_critic)
critic_in_obs = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
critic_in_obs = critic_in_obs.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
critic_in_state = state[:, :, :-1]
critic_in = torch.concat([critic_in_obs, critic_in_state], dim=-1)
_, value_pred = self.policy.get_values(critic_in, IDs[:, :, :-1], *rnn_hidden_critic)
else:
critic_in = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
critic_in = critic_in.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
Expand Down
Loading

0 comments on commit 0ba2161

Please sign in to comment.