Commit

polish
rongkunxue committed Jun 20, 2024
1 parent 54688fa commit d8b3868
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions ding/policy/qtransformer.py
@@ -330,20 +330,31 @@ def discretization(x):
         action = data["action"]
         next_action = data["next_action"]

-        q = self._learn_model.forward(state, action=action)
+        q_pred = self._learn_model.forward(state, action=action)

         with torch.no_grad():
             q_next_target = self._target_model.forward(next_state, actions=next_action)
             q_next_target = q_next_target.max(dim=-1).values
             q_target = self._target_model.forward(state, actions=action)
             q_target = q_target.max(dim=-1).values
+        q_pred_rest_actions, q_pred_last_action = q_pred[..., :-1], q_pred[..., -1]
+        q_target_rest_actions = q_target[..., 1:]
+        q_next_first_action = q_next_target[..., 0]
+        losses_all_actions_but_last = F.mse_loss(
+            q_pred_rest_actions, q_target_rest_actions, reduction="none"
+        )
+        q_target_last_action = reward[-1] + 0.99 * q_next_first_action
+        losses_last_action = F.mse_loss(
+            q_pred_last_action, q_target_last_action, reduction="none"
+        )
+        td_loss = losses_all_actions_but_last + losses_last_action

         self._optimizer_q.zero_grad()
-        loss_dict["loss"].backward()
+        td_loss.backward()
         self._optimizer_q.step()

         self._forward_learn_cnt += 1
         self._target_model.update(self._learn_model.state_dict())

-        return loss_dict
+        return loss

     def _get_actions(self, obs):
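The new hunk computes a Q-Transformer-style autoregressive TD loss over discretized action dimensions: every action dimension except the last regresses onto the maximum target Q-value of the following dimension at the same timestep, while the last dimension regresses onto the reward plus the discounted maximum target Q-value of the first dimension at the next timestep. Below is a minimal, self-contained sketch of that target structure on pre-reduced Q-values; the tensor shapes, the helper name q_transformer_td_loss, the per-sample reward (rather than reward[-1]), and the mean reduction to a scalar are assumptions for illustration, not part of the commit.

    import torch
    import torch.nn.functional as F


    def q_transformer_td_loss(q_pred, q_target, q_next_target, reward, gamma=0.99):
        # q_pred:        (batch, num_action_dims) Q-values from the online network
        # q_target:      (batch, num_action_dims) target-network Q-values at the same
        #                timestep, already max-reduced over the discretized bins
        # q_next_target: (batch, num_action_dims) target-network Q-values at the next
        #                timestep, also max-reduced
        # reward:        (batch,) reward of each transition (assumption; the commit
        #                uses reward[-1])
        #
        # Dimensions 0..D-2 bootstrap from the max Q of the next action dimension
        # at the same timestep (no reward, no discount within a timestep).
        loss_rest = F.mse_loss(q_pred[..., :-1], q_target[..., 1:])
        # The last dimension bootstraps from the reward plus the discounted max Q
        # of the first action dimension at the next timestep.
        target_last = reward + gamma * q_next_target[..., 0]
        loss_last = F.mse_loss(q_pred[..., -1], target_last)
        # Reduce to a scalar so .backward() can be called directly (the commit keeps
        # reduction="none" and adds the two per-element terms instead).
        return loss_rest + loss_last


    # Usage on dummy tensors: 4 transitions, 6 discretized action dimensions.
    q_pred = torch.randn(4, 6, requires_grad=True)
    q_target = torch.randn(4, 6)
    q_next_target = torch.randn(4, 6)
    reward = torch.randn(4)
    q_transformer_td_loss(q_pred, q_target, q_next_target, reward).backward()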
