diff --git a/docs/source/documents/api/learners/marl/coma.rst b/docs/source/documents/api/learners/marl/coma.rst index 7242fe67..ecdbf448 100644 --- a/docs/source/documents/api/learners/marl/coma.rst +++ b/docs/source/documents/api/learners/marl/coma.rst @@ -10,900 +10,46 @@ xxxxxx. **PyTorch:** .. py:class:: - xuance.torch.learners.multi_agent_rl.BasicQhead(state_dim, action_dim, hidden_sizes, normalize, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.BasicQhead.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.BasicRecurrent(**kwargs) - - :param **kwargs: xxxxxx. - :type **kwargs: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.BasicRecurrent.forward(x, h, c) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :param h: xxxxxx. - :type h: xxxxxx - :param c: xxxxxx. - :type c: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.DuelQhead(state_dim, action_dim, hidden_sizes, normalize, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DuelQhead.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.C51Qhead(state_dim, action_dim, atom_num, hidden_sizes, normalize, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param atom_num: xxxxxx. - :type atom_num: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.C51Qhead.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.QRDQNhead(state_dim, action_dim, atom_num, hidden_sizes, normalize, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param atom_num: xxxxxx. - :type atom_num: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.QRDQNhead.forward(x) - - xxxxxx. - - :param x: xxxxxx. 
- :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.BasicQnetwork(action_space, representation, hidden_size, normalize, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param hidden_size: xxxxxx. - :type hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.BasicQnetwork.forward(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.BasicQnetwork.target(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.BasicQnetwork.copy_target() - - xxxxxx. - - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.DuelQnetwork(action_space, representation, hidden_size, normalize, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param hidden_size: xxxxxx. - :type hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DuelQnetwork.forward(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DuelQnetwork.target(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DuelQnetwork.copy_target() - - xxxxxx. - - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.NoisyQnetwork(action_space, representation, hidden_size, normalize, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param hidden_size: xxxxxx. - :type hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.NoisyQnetwork.update_noise(noisy_bound) - - xxxxxx. - - :param noisy_bound: xxxxxx. - :type noisy_bound: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.NoisyQnetwork.forward(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.NoisyQnetwork.target(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.NoisyQnetwork.copy_target() - - xxxxxx. - - :return: xxxxxx. - :rtype: xxxxxx - -.. 
py:class:: - xuance.torch.policies.deterministic.C51Qnetwork(action_space, atom_num, vmin, vmax, representation, hidden_size, normalize, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param atom_num: xxxxxx. - :type atom_num: xxxxxx - :param vmin: xxxxxx. - :type vmin: xxxxxx - :param vmax: xxxxxx. - :type vmax: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param hidden_size: xxxxxx. - :type hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.C51Qnetwork.forward(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.C51Qnetwork.target(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.C51Qnetwork.copy_target() - - xxxxxx. - - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.QRDQN_Network(action_space, quantile_num, representation, hidden_size, normalize, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param quantile_num: xxxxxx. - :type quantile_num: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param hidden_size: xxxxxx. - :type hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.QRDQN_Network.forward(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.QRDQN_Network.target(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.QRDQN_Network.copy_target() - - xxxxxx. - - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.ActorNet(state_dim, action_dim, hidden_sizes, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.ActorNet.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.CriticNet(state_dim, action_dim, hidden_sizes, initialize, activation, device) - - :param state_dim: xxxxxx. - :type state_dim: xxxxxx - :param action_dim: xxxxxx. - :type action_dim: xxxxxx - :param hidden_sizes: xxxxxx. - :type hidden_sizes: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. 
py:function:: - xuance.torch.policies.deterministic.CriticNet.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.DDPGPolicy(action_space, representation, actor_hidden_size, critic_hidden_size, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param actor_hidden_size: xxxxxx. - :type actor_hidden_size: xxxxxx - :param critic_hidden_size: xxxxxx. - :type critic_hidden_size: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DDPGPolicy.forward(x) - - xxxxxx. - - :param x: xxxxxx. - :type x: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DDPGPolicy.Qtarget(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DDPGPolicy.Qaction(observation, action) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DDPGPolicy.Qpolicy(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.DDPGPolicy.soft_update(tau) - - xxxxxx. - - :param tau: xxxxxx. - :type tau: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.TD3Policy(action_space, representation, actor_hidden_size, critic_hidden_size, initialize, activation, device) - - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param actor_hidden_size: xxxxxx. - :type actor_hidden_size: xxxxxx - :param critic_hidden_size: xxxxxx. - :type critic_hidden_size: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.TD3Policy.action(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.TD3Policy.Qtarget(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.TD3Policy.Qaction(observation, action) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.TD3Policy.Qpolicy(observation) - - xxxxxx. - - :param observation: xxxxxx. - :type observation: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.TD3Policy.soft_update(tau) - - xxxxxx. - - :param tau: xxxxxx. - :type tau: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. 
py:class:: - xuance.torch.policies.deterministic.PDQNPolicy(observation_space, action_space, representation, conactor_hidden_size, qnetwork_hidden_size, normalize, initialize, activation, device) - - :param observation_space: xxxxxx. - :type observation_space: xxxxxx - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param conactor_hidden_size: xxxxxx. - :type conactor_hidden_size: xxxxxx - :param qnetwork_hidden_size: xxxxxx. - :type qnetwork_hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.Atarget(state) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.con_action(state) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.Qtarget(state, action) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.Qeval(state, action) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.Qpolicy(state) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.PDQNPolicy.soft_update(tau) - - xxxxxx. - - :param tau: xxxxxx. - :type tau: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:class:: - xuance.torch.policies.deterministic.MPDQNPolicy(observation_space, action_space, representation, conactor_hidden_size, qnetwork_hidden_size, normalize, initialize, activation, device) - - :param observation_space: xxxxxx. - :type observation_space: xxxxxx - :param action_space: xxxxxx. - :type action_space: xxxxxx - :param representation: xxxxxx. - :type representation: xxxxxx - :param conactor_hidden_size: xxxxxx. - :type conactor_hidden_size: xxxxxx - :param qnetwork_hidden_size: xxxxxx. - :type qnetwork_hidden_size: xxxxxx - :param normalize: xxxxxx. - :type normalize: xxxxxx - :param initialize: xxxxxx. - :type initialize: xxxxxx - :param activation: xxxxxx. - :type activation: xxxxxx - :param device: xxxxxx. - :type device: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.MPDQNPolicy.Atarget(state) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.MPDQNPolicy.con_action(state) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.MPDQNPolicy.Qtarget(state, action) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. - :rtype: xxxxxx - -.. py:function:: - xuance.torch.policies.deterministic.MPDQNPolicy.Qeval(state, action) - - xxxxxx. - - :param state: xxxxxx. - :type state: xxxxxx - :param action: xxxxxx. - :type action: xxxxxx - :return: xxxxxx. 
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.MPDQNPolicy.Qpolicy(state)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.MPDQNPolicy.soft_update(tau)
-
-    xxxxxx.
-
-    :param tau: xxxxxx.
-    :type tau: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:class::
-    xuance.torch.policies.deterministic.SPDQNPolicy(observation_space, action_space, representation, conactor_hidden_size, qnetwork_hidden_size, normalize, initialize, activation, device)
-
-    :param observation_space: xxxxxx.
-    :type observation_space: xxxxxx
-    :param action_space: xxxxxx.
-    :type action_space: xxxxxx
-    :param representation: xxxxxx.
-    :type representation: xxxxxx
-    :param conactor_hidden_size: xxxxxx.
-    :type conactor_hidden_size: xxxxxx
-    :param qnetwork_hidden_size: xxxxxx.
-    :type qnetwork_hidden_size: xxxxxx
-    :param normalize: xxxxxx.
-    :type normalize: xxxxxx
-    :param initialize: xxxxxx.
-    :type initialize: xxxxxx
-    :param activation: xxxxxx.
-    :type activation: xxxxxx
+    xuance.torch.learners.multi_agent_rl.coma_learner.COMA_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+    :param config: the Namespace variable that provides hyper-parameters and other settings.
+    :type config: Namespace
+    :param policy: the COMA policy that provides the actions and the critic values.
+    :type policy: nn.Module
+    :param optimizer: the optimizers for the actor and the critic, respectively.
+    :type optimizer: Sequence[torch.optim.Optimizer]
+    :param scheduler: the learning-rate schedulers for the actor and the critic, respectively.
+    :type scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler]
-    :param device: xxxxxx.
-    :type device: xxxxxx
+    :param device: the calculating device.
+    :type device: str, int, or torch.device
+    :param model_dir: the directory for saving or loading the model parameters.
+    :type model_dir: str
+    :param gamma: the discount factor.
+    :type gamma: float
+    :param sync_frequency: the frequency (in training iterations) at which the target critic is synchronized.
+    :type sync_frequency: int
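+
+    The learner maximizes the counterfactual advantage introduced in the COMA paper,
+
+    .. math::
+
+        A^{a}(s, \mathbf{u}) = Q(s, \mathbf{u}) - \sum_{u'^{a}} \pi^{a}(u'^{a} \mid \tau^{a}) \, Q\big(s, (\mathbf{u}^{-a}, u'^{a})\big),
+
+    where the subtracted term is the counterfactual baseline: the expected critic value when agent :math:`a`'s own action is marginalized out while the other agents' actions :math:`\mathbf{u}^{-a}` are kept fixed.
+
+    Example of building and updating the learner (a minimal sketch; ``make_coma_policy``, ``sample``, and the ``config`` fields shown here are illustrative assumptions, not part of the API):
+
+    .. code-block:: python
+
+        import torch
+        from argparse import Namespace
+        from xuance.torch.learners.multi_agent_rl.coma_learner import COMA_Learner
+
+        # Only the fields read directly by COMA_Learner are shown; the real
+        # config carries the further attributes that LearnerMAS expects.
+        config = Namespace(td_lambda=0.8, use_global_state=True, clip_grad=10.0)
+        policy = make_coma_policy(config)  # hypothetical helper that builds the COMA policy module
+        optimizer = [torch.optim.Adam(policy.parameters_actor, lr=7e-4),
+                     torch.optim.Adam(policy.parameters_critic, lr=7e-4)]
+        scheduler = [torch.optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.5)
+                     for opt in optimizer]
+        learner = COMA_Learner(config, policy, optimizer, scheduler,
+                               device="cuda:0", model_dir="./models/",
+                               gamma=0.99, sync_frequency=100)
+        info = learner.update(sample, epsilon=0.1)  # ``sample``: a batch dict drawn from the buffer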
 
 .. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.Atarget(state)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.con_action(state)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.Qtarget(state, action)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :param action: xxxxxx.
-    :type action: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.Qeval(state, action)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :param action: xxxxxx.
-    :type action: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.Qpolicy(state)
-
-    xxxxxx.
-
-    :param state: xxxxxx.
-    :type state: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.SPDQNPolicy.soft_update(tau)
-
-    xxxxxx.
-
-    :param tau: xxxxxx.
-    :type tau: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:class::
-    xuance.torch.policies.deterministic.DRQNPolicy(action_space, representation, **kwargs)
-
-    :param action_space: xxxxxx.
-    :type action_space: xxxxxx
-    :param representation: xxxxxx.
-    :type representation: xxxxxx
-    :param **kwargs: xxxxxx.
-    :type **kwargs: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.DRQNPolicy.forward(observation, *rnn_hidden)
-
-    xxxxxx.
-
-    :param observation: xxxxxx.
-    :type observation: xxxxxx
-    :param *rnn_hidden: xxxxxx.
-    :type *rnn_hidden: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.DRQNPolicy.target(observation, *rnn_hidden)
-
-    xxxxxx.
-
-    :param observation: xxxxxx.
-    :type observation: xxxxxx
-    :param *rnn_hidden: xxxxxx.
-    :type *rnn_hidden: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.DRQNPolicy.init_hidden(batch)
-
-    xxxxxx.
-
-    :param batch: xxxxxx.
-    :type batch: xxxxxx
-    :return: xxxxxx.
-    :rtype: xxxxxx
-
-.. py:function::
-    xuance.torch.policies.deterministic.DRQNPolicy.init_hidden_item(rnn_hidden, i)
+    xuance.torch.learners.multi_agent_rl.coma_learner.COMA_Learner.update(sample, epsilon)
 
-    xxxxxx.
+    Calculate the COMA losses from a sampled batch and update the critic and the actor.
 
-    :param rnn_hidden: xxxxxx.
-    :type rnn_hidden: xxxxxx
-    :param i: xxxxxx.
-    :type i: xxxxxx
+    :param sample: a batch of experience data sampled from the replay buffer.
+    :type sample: dict
+    :param epsilon: the epsilon-greedy factor used when evaluating the actor's action probabilities.
+    :type epsilon: float
-    :return: xxxxxx.
-    :rtype: xxxxxx
+    :return: the information of the training, including the losses, learning rates, and gradient norms.
+    :rtype: dict
 
 .. py:function::
-    xuance.torch.policies.deterministic.DRQNPolicy.copy_target()
+    xuance.torch.learners.multi_agent_rl.coma_learner.COMA_Learner.update_recurrent(sample, epsilon)
 
-    xxxxxx.
+    Calculate the COMA losses from sampled episodes and update the critic and the actor, for agents with RNN-based representations.
 
+    :param sample: a batch of episode data sampled from the replay buffer.
+    :type sample: dict
+    :param epsilon: the epsilon-greedy factor used when evaluating the actor's action probabilities.
+    :type epsilon: float
-    :return: xxxxxx.
-    :rtype: xxxxxx
+    :return: the information of the training, including the losses, learning rates, and gradient norms.
+    :rtype: dict
 
@@ -932,847 +78,202 @@ Source Code
 
 .. code-block:: python
 
-    from xuance.torch.policies import *
-    from xuance.torch.utils import *
-    from xuance.torch.representations import Basic_Identical
-
-
-    class BasicQhead(nn.Module):
-        def __init__(self,
-                     state_dim: int,
-                     action_dim: int,
-                     hidden_sizes: Sequence[int],
-                     normalize: Optional[ModuleType] = None,
-                     initialize: Optional[Callable[..., torch.Tensor]] = None,
-                     activation: Optional[ModuleType] = None,
-                     device: Optional[Union[str, int, torch.device]] = None):
-            super(BasicQhead, self).__init__()
-            layers = []
-            input_shape = (state_dim,)
-            for h in hidden_sizes:
-                mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
-                layers.extend(mlp)
-            layers.extend(mlp_block(input_shape[0], action_dim, None, None, None, device)[0])
-            self.model = nn.Sequential(*layers)
-
-        def forward(self, x: torch.Tensor):
-            return self.model(x)
-
-
-    class BasicRecurrent(nn.Module):
-        def __init__(self, **kwargs):
-            super(BasicRecurrent, self).__init__()
-            self.lstm = False
-            if kwargs["rnn"] == "GRU":
-                output = gru_block(kwargs["input_dim"],
-                                   kwargs["recurrent_hidden_size"],
-                                   kwargs["recurrent_layer_N"],
-                                   kwargs["dropout"],
-                                   kwargs["initialize"],
-                                   kwargs["device"])
-            elif kwargs["rnn"] == "LSTM":
-                self.lstm = True
-                output = lstm_block(kwargs["input_dim"],
-                                    kwargs["recurrent_hidden_size"],
-                                    kwargs["recurrent_layer_N"],
-                                    kwargs["dropout"],
-                                    kwargs["initialize"],
-                                    kwargs["device"])
-            else:
-                raise "Unknown recurrent module!"
- self.rnn_layer = output - fc_layer = mlp_block(kwargs["recurrent_hidden_size"], kwargs["action_dim"], None, None, None, kwargs["device"])[0] - self.model = nn.Sequential(*fc_layer) - - def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor = None): - self.rnn_layer.flatten_parameters() - if self.lstm: - output, (hn, cn) = self.rnn_layer(x, (h, c)) - return hn, cn, self.model(output) - else: - output, hn = self.rnn_layer(x, h) - return hn, self.model(output) - - - class DuelQhead(nn.Module): - def __init__(self, - state_dim: int, - action_dim: int, - hidden_sizes: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(DuelQhead, self).__init__() - v_layers = [] - input_shape = (state_dim,) - for h in hidden_sizes: - v_mlp, input_shape = mlp_block(input_shape[0], h // 2, normalize, activation, initialize, device) - v_layers.extend(v_mlp) - v_layers.extend(mlp_block(input_shape[0], 1, None, None, None, device)[0]) - a_layers = [] - input_shape = (state_dim,) - for h in hidden_sizes: - a_mlp, input_shape = mlp_block(input_shape[0], h // 2, normalize, activation, initialize, device) - a_layers.extend(a_mlp) - a_layers.extend(mlp_block(input_shape[0], action_dim, None, None, None, device)[0]) - self.a_model = nn.Sequential(*a_layers) - self.v_model = nn.Sequential(*v_layers) - - def forward(self, x: torch.Tensor): - v = self.v_model(x) - a = self.a_model(x) - q = v + (a - a.mean(dim=-1).unsqueeze(dim=-1)) - return q - - - class C51Qhead(nn.Module): - def __init__(self, - state_dim: int, - action_dim: int, - atom_num: int, - hidden_sizes: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(C51Qhead, self).__init__() - self.action_dim = action_dim - self.atom_num = atom_num - layers = [] - input_shape = (state_dim,) - for h in hidden_sizes: - mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device) - layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], action_dim * atom_num, None, None, None, device)[0]) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor): - dist_logits = self.model(x).view(-1, self.action_dim, self.atom_num) - dist_probs = F.softmax(dist_logits, dim=-1) - return dist_probs - - - class QRDQNhead(nn.Module): - def __init__(self, - state_dim: int, - action_dim: int, - atom_num: int, - hidden_sizes: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(QRDQNhead, self).__init__() - self.action_dim = action_dim - self.atom_num = atom_num - layers = [] - input_shape = (state_dim,) - for h in hidden_sizes: - mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device) - layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], action_dim * atom_num, None, None, None, device)[0]) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor): - quantiles = self.model(x).view(-1, self.action_dim, self.atom_num) - return quantiles - - - class BasicQnetwork(nn.Module): - def __init__(self, - action_space: Discrete, - representation: nn.Module, - 
hidden_size: Sequence[int] = None, - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(BasicQnetwork, self).__init__() - self.action_dim = action_space.n - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = self.representation.output_shapes - self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size, - normalize, initialize, activation, device) - self.target_Qhead = copy.deepcopy(self.eval_Qhead) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - evalQ = self.eval_Qhead(outputs['state']) - argmax_action = evalQ.argmax(dim=-1) - return outputs, argmax_action, evalQ - - def target(self, observation: Union[np.ndarray, dict]): - outputs_target = self.target_representation(observation) - targetQ = self.target_Qhead(outputs_target['state']) - argmax_action = targetQ.argmax(dim=-1) - return outputs_target, argmax_action.detach(), targetQ.detach() - - def copy_target(self): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.copy_(ep) - for ep, tp in zip(self.eval_Qhead.parameters(), self.target_Qhead.parameters()): - tp.data.copy_(ep) - - - class DuelQnetwork(nn.Module): - def __init__(self, - action_space: Space, - representation: nn.Module, - hidden_size: Sequence[int] = None, - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(DuelQnetwork, self).__init__() - self.action_dim = action_space.n - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = self.representation.output_shapes - self.eval_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size, - normalize, initialize, activation, device) - self.target_Qhead = copy.deepcopy(self.eval_Qhead) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - evalQ = self.eval_Qhead(outputs['state']) - argmax_action = evalQ.argmax(dim=-1) - return outputs, argmax_action, evalQ - - def target(self, observation: Union[np.ndarray, dict]): - outputs = self.target_representation(observation) - targetQ = self.target_Qhead(outputs['state']) - argmax_action = targetQ.argmax(dim=-1) - return outputs, argmax_action, targetQ - - def copy_target(self): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.copy_(ep) - for ep, tp in zip(self.eval_Qhead.parameters(), self.target_Qhead.parameters()): - tp.data.copy_(ep) - - - class NoisyQnetwork(nn.Module): - def __init__(self, - action_space: Discrete, - representation: nn.Module, - hidden_size: Sequence[int] = None, - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(NoisyQnetwork, self).__init__() - self.action_dim = action_space.n - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = 
self.representation.output_shapes - self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size, - normalize, initialize, activation, device) - self.target_Qhead = copy.deepcopy(self.eval_Qhead) - self.noise_scale = 0.0 - - def update_noise(self, noisy_bound: float = 0.0): - self.eval_noise_parameter = [] - self.target_noise_parameter = [] - for parameter in self.eval_Qhead.parameters(): - self.eval_noise_parameter.append(torch.randn_like(parameter) * noisy_bound) - self.target_noise_parameter.append(torch.randn_like(parameter) * noisy_bound) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - self.update_noise(self.noise_scale) - for parameter, noise_param in zip(self.eval_Qhead.parameters(), self.eval_noise_parameter): - parameter.data.copy_(parameter.data + noise_param) - evalQ = self.eval_Qhead(outputs['state']) - argmax_action = evalQ.argmax(dim=-1) - return outputs, argmax_action, evalQ - - def target(self, observation: Union[np.ndarray, dict]): - outputs = self.target_representation(observation) - self.update_noise(self.noise_scale) - for parameter, noise_param in zip(self.target_Qhead.parameters(), self.target_noise_parameter): - parameter.data.copy_(parameter.data + noise_param) - targetQ = self.target_Qhead(outputs['state']) - argmax_action = targetQ.argmax(dim=-1) - return outputs, argmax_action, targetQ.detach() - - def copy_target(self): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.copy_(ep) - for ep, tp in zip(self.eval_Qhead.parameters(), self.target_Qhead.parameters()): - tp.data.copy_(ep) - - - class C51Qnetwork(nn.Module): - def __init__(self, - action_space: Discrete, - atom_num: int, - vmin: float, - vmax: float, - representation: nn.Module, - hidden_size: Sequence[int] = None, - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(C51Qnetwork, self).__init__() - self.action_dim = action_space.n - self.atom_num = atom_num - self.vmin = vmin - self.vmax = vmax - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = self.representation.output_shapes - self.eval_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim, self.atom_num, - hidden_size, - normalize, initialize, activation, device) - self.target_Zhead = copy.deepcopy(self.eval_Zhead) - self.supports = torch.nn.Parameter(torch.linspace(self.vmin, self.vmax, self.atom_num), requires_grad=False).to( - device) - self.deltaz = (vmax - vmin) / (atom_num - 1) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - eval_Z = self.eval_Zhead(outputs['state']) - eval_Q = (self.supports * eval_Z).sum(-1) - argmax_action = eval_Q.argmax(dim=-1) - return outputs, argmax_action, eval_Z - - def target(self, observation: Union[np.ndarray, dict]): - outputs = self.target_representation(observation) - target_Z = self.target_Zhead(outputs['state']) - target_Q = (self.supports * target_Z).sum(-1) - argmax_action = target_Q.argmax(dim=-1) - return outputs, argmax_action, target_Z - - def copy_target(self): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.copy_(ep) - for ep, tp in zip(self.eval_Zhead.parameters(), 
self.target_Zhead.parameters()): - tp.data.copy_(ep) - - - class QRDQN_Network(nn.Module): - def __init__(self, - action_space: Discrete, - quantile_num: int, - representation: nn.Module, - hidden_size: Sequence[int] = None, - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(QRDQN_Network, self).__init__() - self.action_dim = action_space.n - self.quantile_num = quantile_num - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = self.representation.output_shapes - self.eval_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim, self.quantile_num, - hidden_size, - normalize, initialize, activation, device) - self.target_Zhead = copy.deepcopy(self.eval_Zhead) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - eval_Z = self.eval_Zhead(outputs['state']) - eval_Q = eval_Z.mean(dim=-1) - argmax_action = eval_Q.argmax(dim=-1) - return outputs, argmax_action, eval_Z - - def target(self, observation: Union[np.ndarray, dict]): - outputs = self.target_representation(observation) - target_Z = self.target_Zhead(outputs['state']) - target_Q = target_Z.mean(dim=-1) - argmax_action = target_Q.argmax(dim=-1) - return outputs, argmax_action, target_Z - - def copy_target(self): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.copy_(ep) - for ep, tp in zip(self.eval_Zhead.parameters(), self.target_Zhead.parameters()): - tp.data.copy_(ep) - - - class ActorNet(nn.Module): - def __init__(self, - state_dim: int, - action_dim: int, - hidden_sizes: Sequence[int], - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(ActorNet, self).__init__() - layers = [] - input_shape = (state_dim,) - for h in hidden_sizes: - mlp, input_shape = mlp_block(input_shape[0], h, None, activation, initialize, device) - layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], action_dim, None, nn.Tanh, initialize, device)[0]) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.tensor): - return self.model(x) - - - class CriticNet(nn.Module): - def __init__(self, - state_dim: int, - action_dim: int, - hidden_sizes: Sequence[int], - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(CriticNet, self).__init__() - layers = [] - input_shape = (state_dim + action_dim,) - for h in hidden_sizes: - mlp, input_shape = mlp_block(input_shape[0], h, None, activation, initialize, device) - layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], 1, None, None, initialize, device)[0]) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.tensor, a: torch.tensor): - return self.model(torch.concat((x, a), dim=-1))[:, 0] - - - class DDPGPolicy(nn.Module): - def __init__(self, - action_space: Space, - representation: nn.Module, - actor_hidden_size: Sequence[int], - critic_hidden_size: Sequence[int], - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(DDPGPolicy, self).__init__() - self.action_dim = 
action_space.shape[0] - self.representation_info_shape = representation.output_shapes - self.representation = representation - self.actor = ActorNet(representation.output_shapes['state'][0], self.action_dim, actor_hidden_size, initialize, - activation, device) - self.critic = CriticNet(representation.output_shapes['state'][0], self.action_dim, critic_hidden_size, - initialize, activation, device) - self.target_actor = copy.deepcopy(self.actor) - self.target_critic = copy.deepcopy(self.critic) - - def forward(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - act = self.actor(outputs['state']) - return outputs, act - - def Qtarget(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - act = self.target_actor(outputs['state']) - # noise = torch.randn_like(act).clamp(-1, 1) * 0.1 - # act = (act + noise).clamp(-1, 1) - return self.target_critic(outputs['state'], act) - - def Qaction(self, observation: Union[np.ndarray, dict], action: torch.Tensor): - outputs = self.representation(observation) - return self.critic(outputs['state'], action) - - def Qpolicy(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - return self.critic(outputs['state'], self.actor(outputs['state'])) - - def soft_update(self, tau=0.005): - for ep, tp in zip(self.actor.parameters(), self.target_actor.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.critic.parameters(), self.target_critic.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - - - class TD3Policy(nn.Module): - def __init__(self, - action_space: Space, - representation: nn.Module, - actor_hidden_size: Sequence[int], - critic_hidden_size: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(TD3Policy, self).__init__() - self.action_dim = action_space.shape[0] - self.representation = representation - self.representation_info_shape = self.representation.output_shapes - self.actor = ActorNet(representation.output_shapes['state'][0], self.action_dim, actor_hidden_size, - initialize, activation, device) - self.criticA = CriticNet(representation.output_shapes['state'][0], self.action_dim, critic_hidden_size, - initialize, activation, device) - self.criticB = CriticNet(representation.output_shapes['state'][0], self.action_dim, critic_hidden_size, - initialize, activation, device) - self.target_actor = copy.deepcopy(self.actor) - self.target_criticA = copy.deepcopy(self.criticA) - self.target_criticB = copy.deepcopy(self.criticB) - - def action(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - act = self.actor(outputs['state']) - return outputs, act - - def Qtarget(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - act = self.target_actor(outputs['state']) - noise = torch.randn_like(act).clamp(-0.1, 0.1) * 0.1 - act = (act + noise).clamp(-1, 1) - qa = self.target_criticA(outputs['state'], act).unsqueeze(dim=1) - qb = self.target_criticB(outputs['state'], act).unsqueeze(dim=1) - mim_q = torch.minimum(qa, qb) - return outputs, mim_q - - def Qaction(self, observation: Union[np.ndarray, dict], action: torch.Tensor): - outputs = self.representation(observation) - qa = self.criticA(outputs['state'], action).unsqueeze(dim=1) - qb = 
self.criticB(outputs['state'], action).unsqueeze(dim=1) - return outputs, torch.cat((qa, qb), axis=-1) - - def Qpolicy(self, observation: Union[np.ndarray, dict]): - outputs = self.representation(observation) - act = self.actor(outputs['state']) - qa = self.criticA(outputs['state'], act).unsqueeze(dim=1) - qb = self.criticB(outputs['state'], act).unsqueeze(dim=1) - return outputs, (qa + qb) / 2.0 - - def soft_update(self, tau=0.005): - for ep, tp in zip(self.actor.parameters(), self.target_actor.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.criticA.parameters(), self.target_criticA.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.criticB.parameters(), self.target_criticB.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - - - class PDQNPolicy(nn.Module): - def __init__(self, - observation_space, - action_space, - representation: nn.Module, - conactor_hidden_size: Sequence[int], - qnetwork_hidden_size: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(PDQNPolicy, self).__init__() - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.observation_space = observation_space - self.action_space = action_space - self.num_disact = self.action_space.spaces[0].n - self.conact_sizes = np.array([self.action_space.spaces[i].shape[0] for i in range(1, self.num_disact + 1)]) - self.conact_size = int(self.conact_sizes.sum()) - - self.qnetwork = BasicQhead(self.observation_space.shape[0] + self.conact_size, self.num_disact, - qnetwork_hidden_size, normalize, - initialize, torch.nn.modules.activation.ReLU, device) - self.conactor = ActorNet(self.observation_space.shape[0], self.conact_size, conactor_hidden_size, - initialize, torch.nn.modules.activation.ReLU, device) - self.target_conactor = copy.deepcopy(self.conactor) - self.target_qnetwork = copy.deepcopy(self.qnetwork) - - def Atarget(self, state): - target_conact = self.target_conactor(state) - return target_conact - - def con_action(self, state): - conaction = self.conactor(state) - return conaction - - def Qtarget(self, state, action): - input_q = torch.cat((state, action), dim=1) - target_q = self.target_qnetwork(input_q) - return target_q - - def Qeval(self, state, action): - input_q = torch.cat((state, action), dim=1) - eval_q = self.qnetwork(input_q) - return eval_q - - def Qpolicy(self, state): - conact = self.conactor(state) - input_q = torch.cat((state, conact), dim=1) - policy_q = torch.sum(self.qnetwork(input_q)) - return policy_q - - def soft_update(self, tau=0.005): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.conactor.parameters(), self.target_conactor.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.qnetwork.parameters(), self.target_qnetwork.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - - - class MPDQNPolicy(nn.Module): - def __init__(self, - observation_space, - action_space, - representation: nn.Module, - conactor_hidden_size: Sequence[int], - qnetwork_hidden_size: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: 
Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(MPDQNPolicy, self).__init__() - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.observation_space = observation_space - self.obs_size = self.observation_space.shape[0] - self.action_space = action_space - self.num_disact = self.action_space.spaces[0].n - self.conact_sizes = np.array([self.action_space.spaces[i].shape[0] for i in range(1, self.num_disact + 1)]) - self.conact_size = int(self.conact_sizes.sum()) - - self.qnetwork = BasicQhead(self.observation_space.shape[0] + self.conact_size, self.num_disact, - qnetwork_hidden_size, normalize, - initialize, torch.nn.modules.activation.ReLU, device) - self.conactor = ActorNet(self.observation_space.shape[0], self.conact_size, conactor_hidden_size, - initialize, torch.nn.modules.activation.ReLU, device) - self.target_conactor = copy.deepcopy(self.conactor) - self.target_qnetwork = copy.deepcopy(self.qnetwork) - - self.offsets = self.conact_sizes.cumsum() - self.offsets = np.insert(self.offsets, 0, 0) - - def Atarget(self, state): - target_conact = self.target_conactor(state) - return target_conact - - def con_action(self, state): - conaction = self.conactor(state) - return conaction - - def Qtarget(self, state, action): - batch_size = state.shape[0] - Q = [] - input_q = torch.cat((state, torch.zeros_like(action)), dim=1) - input_q = input_q.repeat(self.num_disact, 1) - for i in range(self.num_disact): - input_q[i * batch_size:(i + 1) * batch_size, - self.obs_size + self.offsets[i]: self.obs_size + self.offsets[i + 1]] \ - = action[:, self.offsets[i]:self.offsets[i + 1]] - eval_qall = self.target_qnetwork(input_q) - for i in range(self.num_disact): - eval_q = eval_qall[i * batch_size:(i + 1) * batch_size, i] - if len(eval_q.shape) == 1: - eval_q = eval_q.unsqueeze(1) - Q.append(eval_q) - Q = torch.cat(Q, dim=1) - return Q - - def Qeval(self, state, action): - batch_size = state.shape[0] - Q = [] - input_q = torch.cat((state, torch.zeros_like(action)), dim=1) - input_q = input_q.repeat(self.num_disact, 1) - for i in range(self.num_disact): - input_q[i * batch_size:(i + 1) * batch_size, - self.obs_size + self.offsets[i]: self.obs_size + self.offsets[i + 1]] \ - = action[:, self.offsets[i]:self.offsets[i + 1]] - eval_qall = self.qnetwork(input_q) - for i in range(self.num_disact): - eval_q = eval_qall[i * batch_size:(i + 1) * batch_size, i] - if len(eval_q.shape) == 1: - eval_q = eval_q.unsqueeze(1) - Q.append(eval_q) - Q = torch.cat(Q, dim=1) - return Q - - def Qpolicy(self, state): - conact = self.conactor(state) - batch_size = state.shape[0] - Q = [] - input_q = torch.cat((state, torch.zeros_like(conact)), dim=1) - input_q = input_q.repeat(self.num_disact, 1) - for i in range(self.num_disact): - input_q[i * batch_size:(i + 1) * batch_size, - self.obs_size + self.offsets[i]: self.obs_size + self.offsets[i + 1]] \ - = conact[:, self.offsets[i]:self.offsets[i + 1]] - eval_qall = self.qnetwork(input_q) - for i in range(self.num_disact): - eval_q = eval_qall[i * batch_size:(i + 1) * batch_size, i] - if len(eval_q.shape) == 1: - eval_q = eval_q.unsqueeze(1) - Q.append(eval_q) - Q = torch.cat(Q, dim=1) - return Q - - def soft_update(self, tau=0.005): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.conactor.parameters(), self.target_conactor.parameters()): - tp.data.mul_(1 - 
tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.qnetwork.parameters(), self.target_qnetwork.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - - - class SPDQNPolicy(nn.Module): - def __init__(self, - observation_space, - action_space, - representation: nn.Module, - conactor_hidden_size: Sequence[int], - qnetwork_hidden_size: Sequence[int], - normalize: Optional[ModuleType] = None, - initialize: Optional[Callable[..., torch.Tensor]] = None, - activation: Optional[ModuleType] = None, - device: Optional[Union[str, int, torch.device]] = None): - super(SPDQNPolicy, self).__init__() - self.representation = representation - self.target_representation = copy.deepcopy(representation) - self.observation_space = observation_space - self.action_space = action_space - self.num_disact = self.action_space.spaces[0].n - self.conact_sizes = np.array([self.action_space.spaces[i].shape[0] for i in range(1, self.num_disact + 1)]) - self.conact_size = int(self.conact_sizes.sum()) - self.qnetwork = nn.ModuleList() - for k in range(self.num_disact): - self.qnetwork.append( - BasicQhead(self.observation_space.shape[0] + self.conact_sizes[k], 1, qnetwork_hidden_size, normalize, - initialize, torch.nn.modules.activation.ReLU, device)) - self.conactor = ActorNet(self.observation_space.shape[0], self.conact_size, conactor_hidden_size, - initialize, torch.nn.modules.activation.ReLU, device) - self.target_conactor = copy.deepcopy(self.conactor) - self.target_qnetwork = copy.deepcopy(self.qnetwork) - - self.offsets = self.conact_sizes.cumsum() - self.offsets = np.insert(self.offsets, 0, 0) - - def Atarget(self, state): - target_conact = self.target_conactor(state) - return target_conact - - def con_action(self, state): - conaction = self.conactor(state) - return conaction - - def Qtarget(self, state, action): - target_Q = [] - for i in range(self.num_disact): - conact = action[:, self.offsets[i]:self.offsets[i + 1]] - input_q = torch.cat((state, conact), dim=1) - eval_q = self.target_qnetwork[i](input_q) - target_Q.append(eval_q) - target_Q = torch.cat(target_Q, dim=1) - return target_Q - - def Qeval(self, state, action): - Q = [] - for i in range(self.num_disact): - conact = action[:, self.offsets[i]:self.offsets[i + 1]] - input_q = torch.cat((state, conact), dim=1) - eval_q = self.qnetwork[i](input_q) - Q.append(eval_q) - Q = torch.cat(Q, dim=1) - return Q - - def Qpolicy(self, state): - conacts = self.conactor(state) - Q = [] - for i in range(self.num_disact): - conact = conacts[:, self.offsets[i]:self.offsets[i + 1]] - input_q = torch.cat((state, conact), dim=1) - eval_q = self.qnetwork[i](input_q) - Q.append(eval_q) - Q = torch.cat(Q, dim=1) - return Q - - def soft_update(self, tau=0.005): - for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.conactor.parameters(), self.target_conactor.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - for ep, tp in zip(self.qnetwork.parameters(), self.target_qnetwork.parameters()): - tp.data.mul_(1 - tau) - tp.data.add_(tau * ep.data) - - - class DRQNPolicy(nn.Module): - def __init__(self, - action_space: Discrete, - representation: nn.Module, - **kwargs): - super(DRQNPolicy, self).__init__() - self.device = kwargs['device'] - self.recurrent_layer_N = kwargs['recurrent_layer_N'] - self.rnn_hidden_dim = kwargs['recurrent_hidden_size'] - self.action_dim = action_space.n - self.representation = representation - 
self.target_representation = copy.deepcopy(representation) - self.representation_info_shape = self.representation.output_shapes - kwargs["input_dim"] = self.representation.output_shapes['state'][0] - kwargs["action_dim"] = self.action_dim - self.lstm = True if kwargs["rnn"] == "LSTM" else False - self.cnn = True if self.representation._get_name() == "Basic_CNN" else False - self.eval_Qhead = BasicRecurrent(**kwargs) - self.target_Qhead = copy.deepcopy(self.eval_Qhead) - - def forward(self, observation: Union[np.ndarray, dict], *rnn_hidden: torch.Tensor): - if self.cnn: - obs_shape = observation.shape - outputs = self.representation(observation.reshape((-1,) + obs_shape[-3:])) - outputs['state'] = outputs['state'].reshape(obs_shape[0:-3] + (-1,)) - else: - outputs = self.representation(observation) - if self.lstm: - hidden_states, cell_states, evalQ = self.eval_Qhead(outputs['state'], rnn_hidden[0], rnn_hidden[1]) - else: - hidden_states, evalQ = self.eval_Qhead(outputs['state'], rnn_hidden[0]) - cell_states = None - argmax_action = evalQ[:, -1].argmax(dim=-1) - return outputs, argmax_action, evalQ, (hidden_states, cell_states) - - def target(self, observation: Union[np.ndarray, dict], *rnn_hidden: torch.Tensor): - if self.cnn: - obs_shape = observation.shape - outputs = self.representation(observation.reshape((-1,) + obs_shape[-3:])) - outputs['state'] = outputs['state'].reshape(obs_shape[0:-3] + (-1,)) - else: - outputs = self.representation(observation) - if self.lstm: - hidden_states, cell_states, targetQ = self.target_Qhead(outputs['state'], rnn_hidden[0], rnn_hidden[1]) - else: - hidden_states, targetQ = self.target_Qhead(outputs['state'], rnn_hidden[0]) - cell_states = None - argmax_action = targetQ.argmax(dim=-1) - return outputs, argmax_action, targetQ.detach(), (hidden_states, cell_states) - - def init_hidden(self, batch): - hidden_states = torch.zeros(size=(self.recurrent_layer_N, batch, self.rnn_hidden_dim)).to(self.device) - cell_states = torch.zeros_like(hidden_states).to(self.device) if self.lstm else None - return hidden_states, cell_states + """ + COMA: Counterfactual Multi-Agent Policy Gradients + Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/11794 + Implementation: Pytorch + """ + import torch + + from xuance.torch.learners import * + + + class COMA_Learner(LearnerMAS): + def __init__(self, + config: Namespace, + policy: nn.Module, + optimizer: Sequence[torch.optim.Optimizer], + scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler] = None, + device: Optional[Union[int, str, torch.device]] = None, + model_dir: str = "./", + gamma: float = 0.99, + sync_frequency: int = 100 + ): + self.gamma = gamma + self.td_lambda = config.td_lambda + self.sync_frequency = sync_frequency + self.use_global_state = config.use_global_state + self.mse_loss = nn.MSELoss() + super(COMA_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir) + self.optimizer = { + 'actor': optimizer[0], + 'critic': optimizer[1] + } + self.scheduler = { + 'actor': scheduler[0], + 'critic': scheduler[1] + } + self.iterations_actor = self.iterations + self.iterations_critic = 0 + + def update(self, sample, epsilon=0.0): + self.iterations += 1 + state = torch.Tensor(sample['state']).to(self.device) + obs = torch.Tensor(sample['obs']).to(self.device) + actions = torch.Tensor(sample['actions']).to(self.device) + actions_onehot = torch.Tensor(sample['actions_onehot']).to(self.device) + targets = torch.Tensor(sample['returns']).squeeze(-1).to(self.device) + agent_mask = 
torch.Tensor(sample['agent_mask']).float().to(self.device)
+            batch_size = obs.shape[0]
+            IDs = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
+
+            # build critic input
+            actions_in = actions_onehot.unsqueeze(1).reshape(batch_size, 1, -1).repeat(1, self.n_agents, 1)
+            actions_in_mask = 1 - torch.eye(self.n_agents, device=self.device)
+            actions_in_mask = actions_in_mask.reshape(-1, 1).repeat(1, self.dim_act).reshape(self.n_agents, -1)
+            actions_in = actions_in * actions_in_mask.unsqueeze(0)
+            if self.use_global_state:
+                state = state.unsqueeze(1).repeat(1, self.n_agents, 1)
+                critic_in = torch.concat([state, obs, actions_in], dim=-1)
+            else:
+                critic_in = torch.concat([obs, actions_in], dim=-1)
+            # get critic value
+            _, q_eval = self.policy.get_values(critic_in)
+            q_eval_a = q_eval.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
+            q_eval_a *= agent_mask
+            targets *= agent_mask
+            loss_c = ((q_eval_a - targets.detach()) ** 2).sum() / agent_mask.sum()
+            self.optimizer['critic'].zero_grad()
+            loss_c.backward()
+            grad_norm_critic = torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.clip_grad)
+            self.optimizer['critic'].step()
+            if self.iterations_critic % self.sync_frequency == 0:
+                self.policy.copy_target()
+            self.iterations_critic += 1
+
+            if self.scheduler['critic'] is not None:
+                self.scheduler['critic'].step()
+
+            # calculate baselines: the counterfactual baseline is the expected critic
+            # value under the current policy, marginalizing out each agent's own action
+            _, pi_probs = self.policy(obs, IDs, epsilon=epsilon)
+            baseline = (pi_probs * q_eval).sum(-1).detach()
+
+            pi_a = pi_probs.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
+            log_pi_a = torch.log(pi_a)
+            advantages = (q_eval_a - baseline).detach()
+            log_pi_a *= agent_mask
+            advantages *= agent_mask
+            loss_coma = -(advantages * log_pi_a).sum() / agent_mask.sum()
+
+            self.optimizer['actor'].zero_grad()
+            loss_coma.backward()
+            grad_norm_actor = torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.clip_grad)
+            self.optimizer['actor'].step()
+
+            if self.scheduler['actor'] is not None:
+                self.scheduler['actor'].step()
+
+            # Logger
+            lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+            lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']
+
+            info = {
+                "learning_rate_actor": lr_a,
+                "learning_rate_critic": lr_c,
+                "actor_loss": loss_coma.item(),
+                "critic_loss": loss_c.item(),
+                "advantage": advantages.mean().item(),
+                "actor_gradient_norm": grad_norm_actor.item(),
+                "critic_gradient_norm": grad_norm_critic.item()
+            }
+
+            return info
+
+        def update_recurrent(self, sample, epsilon=0.0):
+            self.iterations += 1
+            state = torch.Tensor(sample['state']).to(self.device)
+            obs = torch.Tensor(sample['obs']).to(self.device)
+            actions = torch.Tensor(sample['actions']).to(self.device)
+            actions_onehot = torch.Tensor(sample['actions_onehot']).to(self.device)
+            targets = torch.Tensor(sample['returns']).squeeze(-1).to(self.device)
+            avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device)
+            filled = torch.Tensor(sample['filled']).float().to(self.device)
+            batch_size = obs.shape[0]
+            episode_length = actions.shape[2]
+            IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to(
+                self.device)
+
+            # build critic input
+            actions_in = actions_onehot.transpose(1, 2).reshape(batch_size, episode_length, -1)
+            actions_in = actions_in.unsqueeze(1).repeat(1, self.n_agents, 1, 1)
+            actions_in_mask = 1 - torch.eye(self.n_agents, device=self.device)
+            actions_in_mask = actions_in_mask.view(-1, 1).repeat(1, self.dim_act).view(self.n_agents, -1)
+            actions_in_mask = actions_in_mask.unsqueeze(1).repeat(1, episode_length, 1)
+            actions_in = actions_in * actions_in_mask
+            if self.use_global_state:
+                state = state[:, :-1].unsqueeze(1).repeat(1, self.n_agents, 1, 1)
+                critic_in = torch.concat([state, obs[:, :, :-1], actions_in], dim=-1)
+            else:
+                critic_in = torch.concat([obs[:, :, :-1], actions_in], dim=-1)
+
+            # get critic value
+            _, q_eval = self.policy.get_values(critic_in)
+            q_eval_a = q_eval.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
+            filled_n = filled.unsqueeze(1).expand(-1, self.n_agents, -1, -1).squeeze(-1)
+            td_errors = q_eval_a - targets.detach()
+            td_errors *= filled_n
+            loss_c = (td_errors ** 2).sum() / filled_n.sum()
+            self.optimizer['critic'].zero_grad()
+            loss_c.backward()
+            grad_norm_critic = torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.clip_grad)
+            self.optimizer['critic'].step()
+            if self.iterations_critic % self.sync_frequency == 0:
+                self.policy.copy_target()
+            self.iterations_critic += 1
+
+            if self.scheduler['critic'] is not None:
+                self.scheduler['critic'].step()
+
+            # calculate baselines
+            rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents)
+            _, pi_probs = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs),
+                                      IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents),
+                                      *rnn_hidden_actor,
+                                      avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act),
+                                      epsilon=epsilon)
+            pi_probs = pi_probs.reshape(batch_size, self.n_agents, episode_length, self.dim_act)
+            baseline = (pi_probs * q_eval).sum(-1)
+
+            pi_a = pi_probs.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
+            log_pi_a = torch.log(pi_a)
+            advantages = (q_eval_a - baseline).detach()
+            loss_coma = -(advantages * log_pi_a * filled_n).sum() / filled_n.sum()
+
+            self.optimizer['actor'].zero_grad()
+            loss_coma.backward()
+            grad_norm_actor = torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.clip_grad)
+            self.optimizer['actor'].step()
+
+            if self.scheduler['actor'] is not None:
+                self.scheduler['actor'].step()
+
+            # Logger
+            lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+            lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']
+
+            info = {
+                "learning_rate_actor": lr_a,
+                "learning_rate_critic": lr_c,
+                "actor_loss": loss_coma.item(),
+                "critic_loss": loss_c.item(),
+                "advantage": advantages.mean().item(),
+                "actor_gradient_norm": grad_norm_actor.item(),
+                "critic_gradient_norm": grad_norm_critic.item()
+            }
+
+            return info
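+
+
+    # --- Illustrative sketch (an editorial addition, not part of XuanCe) ---
+    # A toy demonstration of the counterfactual baseline and advantage that
+    # COMA_Learner.update() computes above; all shapes below are assumptions.
+    def _coma_advantage_demo():
+        batch_size, n_agents, dim_act = 2, 3, 4
+        q_eval = torch.randn(batch_size, n_agents, dim_act)       # per-action critic values
+        pi_probs = torch.softmax(torch.randn(batch_size, n_agents, dim_act), dim=-1)
+        actions = torch.randint(dim_act, (batch_size, n_agents))  # actions actually taken
+        # counterfactual baseline: expected Q over each agent's own actions
+        baseline = (pi_probs * q_eval).sum(-1)
+        # Q-value of the taken joint action, per agent
+        q_taken = q_eval.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
+        return q_taken - baseline                                 # COMA advantage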
-
-        def init_hidden_item(self, rnn_hidden, i):
-            if self.lstm:
-                rnn_hidden[0][:, i] = torch.zeros(size=(self.recurrent_layer_N, self.rnn_hidden_dim)).to(self.device)
-                rnn_hidden[1][:, i] = torch.zeros(size=(self.recurrent_layer_N, self.rnn_hidden_dim)).to(self.device)
-                return rnn_hidden
-            else:
-                rnn_hidden[:, i] = torch.zeros(size=(self.recurrent_layer_N, self.rnn_hidden_dim)).to(self.device)
-                return rnn_hidden
-
-        def copy_target(self):
-            for ep, tp in zip(self.representation.parameters(), self.target_representation.parameters()):
-                tp.data.copy_(ep)
-            for ep, tp in zip(self.eval_Qhead.parameters(), self.target_Qhead.parameters()):
-                tp.data.copy_(ep)
 
 .. group-tab:: TensorFlow