
Commit

vdac
wenzhangliu committed Oct 19, 2023
1 parent 7540d2d commit 097d9aa
Showing 10 changed files with 278 additions and 62 deletions.
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/25m.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
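
The new `mixer` option controls how per-agent values are combined into one joint value: VDN simply sums them, while QMIX mixes them monotonically with non-negative weights produced by hypernetworks conditioned on the global state, sized by `hidden_dim_mixing_net` and `hidden_dim_hyper_net`. The sketch below is illustrative only, assuming (batch, n_agents) value and (batch, state_dim) state tensors; it is not the repository's QMIX_mixer/VDN_mixer implementation.

# Illustrative only: the two mixing rules selected by the `mixer` option.
import torch
import torch.nn as nn
import torch.nn.functional as F

def vdn_mix(values_n: torch.Tensor) -> torch.Tensor:
    # VDN: the joint value is the plain sum of per-agent values, (batch, n_agents) -> (batch, 1).
    return values_n.sum(dim=-1, keepdim=True)

class QMIXMixerSketch(nn.Module):
    # QMIX: monotonic mixing; hypernetworks map the global state to non-negative mixing weights.
    def __init__(self, state_dim, n_agents, mixing_dim=32, hyper_dim=32):
        super().__init__()
        self.n_agents, self.mixing_dim = n_agents, mixing_dim
        self.hyper_w1 = nn.Sequential(nn.Linear(state_dim, hyper_dim), nn.ReLU(),
                                      nn.Linear(hyper_dim, n_agents * mixing_dim))
        self.hyper_w2 = nn.Sequential(nn.Linear(state_dim, hyper_dim), nn.ReLU(),
                                      nn.Linear(hyper_dim, mixing_dim))
        self.hyper_b1 = nn.Linear(state_dim, mixing_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, hyper_dim), nn.ReLU(),
                                      nn.Linear(hyper_dim, 1))

    def forward(self, values_n, state):
        # values_n: (batch, n_agents), state: (batch, state_dim) -> joint value (batch, 1)
        q = values_n.view(-1, 1, self.n_agents)
        w1 = torch.abs(self.hyper_w1(state)).view(-1, self.n_agents, self.mixing_dim)  # abs() keeps mixing monotonic
        b1 = self.hyper_b1(state).view(-1, 1, self.mixing_dim)
        hidden = F.elu(torch.bmm(q, w1) + b1)
        w2 = torch.abs(self.hyper_w2(state)).view(-1, self.mixing_dim, 1)
        b2 = self.hyper_b2(state).view(-1, 1, 1)
        return (torch.bmm(hidden, w2) + b2).view(-1, 1)

Restricting the mixing weights to be non-negative is what keeps the joint value monotone in each agent's value, so per-agent greedy choices remain consistent with the joint optimum.
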
6 changes: 5 additions & 1 deletion xuanpolicy/configs/vdac/sc2/2m_vs_1z.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
@@ -60,7 +64,7 @@ train_per_step: True
training_frequency: 1

test_steps: 10000
eval_interval: 100000
eval_interval: 5000
test_episode: 5
log_dir: "./logs/vdac/"
model_dir: "./models/vdac/"
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/3m.yaml
@@ -23,6 +23,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/5m_vs_6m.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/8m.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/8m_vs_9m.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/MMM2.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
4 changes: 4 additions & 0 deletions xuanpolicy/configs/vdac/sc2/corridor.yaml
@@ -24,6 +24,10 @@ actor_hidden_size: []
critic_hidden_size: []
activation: "ReLU"

mixer: "QMIX" # choices: VDN (sum), QMIX (monotonic)
hidden_dim_mixing_net: 32 # hidden units of mixing network
hidden_dim_hyper_net: 32 # hidden units of hyper network

seed: 1
parallels: 1
n_size: 128
126 changes: 81 additions & 45 deletions xuanpolicy/torch/agents/multi_agent_rl/vdac_agents.py
@@ -6,73 +6,109 @@ def __init__(self,
config: Namespace,
envs: DummyVecEnv_Pettingzoo,
device: Optional[Union[int, str, torch.device]] = None):
self.device = torch.device("cuda" if (torch.cuda.is_available() and config.device in ["gpu", "cuda:0"]) else "cpu")
self.gamma = config.gamma

self.n_envs = envs.num_envs
self.n_size = config.n_size
self.n_epoch = config.n_epoch
self.n_minibatch = config.n_minibatch
if config.state_space is not None:
config.dim_state, state_shape = config.state_space.shape, config.state_space.shape
config.dim_state, state_shape = config.state_space.shape[0], config.state_space.shape
else:
config.dim_state, state_shape = None, None

input_representation = get_repre_in(config)
representation = REGISTRY_Representation[config.representation](*input_representation)
if config.mixer == "VDN":
mixer = VDN_mixer()
elif config.mixer == "QMIX":
mixer = QMIX_mixer(config.dim_state[0], config.hidden_dim_mixing_net, config.hidden_dim_hyper_net,
config.n_agents, self.device)
else:
mixer = None

input_policy = get_policy_in_marl(config, representation, config.agent_keys, mixer)
policy = REGISTRY_Policy[config.policy](*input_policy)
optimizer = torch.optim.Adam(policy.parameters(), config.learning_rate, eps=1e-5)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5,
total_iters=get_total_iters(config.agent_name, config))
self.use_recurrent = config.use_recurrent
self.use_global_state = config.use_global_state
# create representation for actor
kwargs_rnn = {"N_recurrent_layers": config.N_recurrent_layers,
"dropout": config.dropout,
"rnn": config.rnn} if self.use_recurrent else {}
representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create representation for critic
input_representation[0] = (config.dim_state,) if self.use_global_state else (config.dim_obs * config.n_agents,)
representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create policy
input_policy = get_policy_in_marl(config, (representation, representation_critic))
policy = REGISTRY_Policy[config.policy](*input_policy,
use_recurrent=config.use_recurrent,
rnn=config.rnn,
gain=config.gain)
optimizer = torch.optim.Adam(policy.parameters(),
lr=config.learning_rate, eps=1e-5,
weight_decay=config.weight_decay)
self.observation_space = envs.observation_space
self.action_space = envs.action_space
self.representation_info_shape = policy.representation.output_shapes
self.auxiliary_info_shape = {}

if config.state_space is not None:
config.dim_state, state_shape = config.state_space.shape, config.state_space.shape
else:
config.dim_state, state_shape = None, None
memory = MARL_OnPolicyBuffer(state_shape, config.obs_shape, config.act_shape, config.rew_shape,
config.done_shape, envs.num_envs, config.nsteps, config.nminibatch,
config.use_gae, config.use_advnorm, config.gamma, config.lam)
learner = VDAC_Learner(config, policy, optimizer, scheduler,
config.device, config.model_dir, config.gamma)
buffer = MARL_OnPolicyBuffer_RNN if self.use_recurrent else MARL_OnPolicyBuffer
input_buffer = (config.n_agents, config.state_space.shape, config.obs_shape, config.act_shape, config.rew_shape,
config.done_shape, envs.num_envs, config.n_size,
config.use_gae, config.use_advnorm, config.gamma, config.gae_lambda)
memory = buffer(*input_buffer, max_episode_length=envs.max_episode_length, dim_act=config.dim_act)
self.buffer_size = memory.buffer_size
self.batch_size = self.buffer_size // self.n_minibatch

learner = VDAC_Learner(config, policy, optimizer, None, config.device, config.model_dir, config.gamma)
super(VDAC_Agents, self).__init__(config, envs, policy, memory, learner, device,
config.log_dir, config.model_dir)
self.share_values = True if config.rew_shape[0] == 1 else False
self.on_policy = True

def act(self, obs_n, episode, test_mode, state=None, noise=False):
def act(self, obs_n, *rnn_hidden, avail_actions=None, state=None, test_mode=False):
batch_size = len(obs_n)
agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
states, dists, vs = self.policy(obs_n, agents_id)
if self.args.mixer == "VDN":
vs_tot = self.policy.value_tot(vs).repeat(1, self.n_agents).unsqueeze(-1)
obs_in = torch.Tensor(obs_n).view([batch_size, self.n_agents, -1]).to(self.device)
if self.use_recurrent:
batch_agents = batch_size * self.n_agents
hidden_state, dists = self.policy(obs_in.view(batch_agents, 1, -1),
agents_id.view(batch_agents, 1, -1),
*rnn_hidden,
avail_actions=avail_actions.reshape(batch_agents, 1, -1))
actions = dists.stochastic_sample()
log_pi_a = dists.log_prob(actions).reshape(batch_size, self.n_agents)
actions = actions.reshape(batch_size, self.n_agents)
else:
vs_tot = self.policy.value_tot(vs, state).repeat(1, self.n_agents).unsqueeze(-1)
acts = dists.stochastic_sample()
return acts.detach().cpu().numpy(), vs_tot.detach().cpu().numpy()
hidden_state, dists = self.policy(obs_in, agents_id, avail_actions=avail_actions)
actions = dists.stochastic_sample()
log_pi_a = dists.log_prob(actions)
return hidden_state, actions.detach().cpu().numpy(), log_pi_a.detach().cpu().numpy()

def value(self, obs, state):
batch_size = len(state)
def values(self, obs_n, *rnn_hidden, state=None):
batch_size = len(obs_n)
agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
repre_out = self.policy.representation(obs)
critic_input = torch.concat([torch.Tensor(repre_out['state']), agents_id], dim=-1)
values_n = self.policy.critic(critic_input)
values = self.policy.value_tot(values_n, global_state=state).view(-1, 1).repeat(1, self.n_agents).unsqueeze(-1)
return values.detach().cpu().numpy()
# build critic input
if self.use_global_state:
state = torch.Tensor(state).unsqueeze(1).to(self.device)
critic_in = state.expand(-1, self.n_agents, -1)
else:
critic_in = torch.Tensor(obs_n).view([batch_size, 1, -1]).to(self.device)
critic_in = critic_in.expand(-1, self.n_agents, -1)
# get critic values
if self.use_recurrent:
hidden_state, values_n = self.policy.get_values(critic_in.unsqueeze(2), # add a sequence length axis.
agents_id.unsqueeze(2),
*rnn_hidden)
values_n = values_n.squeeze(2)
else:
hidden_state, values_n = self.policy.get_values(critic_in, agents_id)

return hidden_state, values_n.detach().cpu().numpy()

def train(self, i_episode):
def train(self, i_step):
if self.memory.full:
info_train = {}
for _ in range(self.args.nminibatch * self.args.nepoch):
sample = self.memory.sample()
info_train = self.learner.update(sample)
indexes = np.arange(self.buffer_size)
for _ in range(self.n_epoch):
np.random.shuffle(indexes)
for start in range(0, self.buffer_size, self.batch_size):
end = start + self.batch_size
sample_idx = indexes[start:end]
sample = self.memory.sample(sample_idx)
if self.use_recurrent:
info_train = self.learner.update_recurrent(sample)
else:
info_train = self.learner.update(sample)
self.learner.lr_decay(i_step)
self.memory.clear()
return info_train
else:
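For orientation, here is a hedged sketch of how the revised agent interface is meant to be driven: act() now returns the recurrent hidden state, the sampled actions, and their log-probabilities; values() returns the hidden state and per-agent critic values; train(i_step) only updates once the on-policy buffer is full, running n_epoch passes over shuffled minibatches of size buffer_size // n_minibatch before clearing the buffer. The loop below assumes the non-recurrent case, and the environment methods, memory.store() signature, and global_state() helper are placeholders for illustration, not APIs from this commit.

def collect_and_train(agents, envs, total_steps):
    """Hypothetical driver loop; envs.*, memory.store(), and global_state() are assumed."""
    obs_n = envs.reset()                              # assumed shape: (n_envs, n_agents, dim_obs)
    info_train = {}
    for step in range(total_steps):
        # Decentralized actors: sample joint actions and their log-probabilities.
        _, actions, log_pi = agents.act(obs_n)
        # Per-agent critics, optionally conditioned on the global state.
        state = envs.global_state() if agents.use_global_state else None   # assumed helper
        _, values_n = agents.values(obs_n, state=state)
        next_obs_n, rewards, terminals, infos = envs.step(actions)         # assumed API
        agents.memory.store(obs_n, actions, log_pi, values_n, rewards, terminals)  # assumed signature
        info_train = agents.train(step)   # no-op until the buffer is full, then update and clear
        obs_n = next_obs_n
    return info_train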
