
Commit

polish
rongkunxue committed Jun 21, 2024
1 parent 6e3cf36 commit d536ab1
Showing 2 changed files with 53 additions and 22 deletions.
53 changes: 45 additions & 8 deletions ding/policy/qtransformer.py
@@ -8,7 +8,6 @@
import wandb

# from einops import pack, rearrange

from ding.model import model_wrap
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
@@ -202,6 +201,7 @@ def _init_learn(self) -> None:

# Algorithm config
self._gamma = self._cfg.learn.discount_factor
self._action_dim = self._cfg.model.action_dim

# Init auto alpha
if self._cfg.learn.auto_alpha:
@@ -306,6 +306,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
# ignore_done=self._cfg.learn.ignore_done,
# use_nstep=False,
# )

def discretization(x):
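# Convert each continuous value in x to a discrete bin index based on self._action_values.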
self._action_values = torch.tensor(self._action_values)
indices = torch.zeros_like(x, dtype=torch.long, device=x.device)
@@ -330,8 +331,10 @@ def discretization(x):
next_state = data["next_state"] # torch.Size([2048, 10, 17])
reward = data["reward"][:, -1] # torch.Size([2048])
done = data["done"][:, -1] # torch.Size([2048])
action = data["action"] # torch.Size([2048, 6, 256])
next_action = data["next_action"] # torch.Size([2048, 6, 256])
action = data["action"]
next_action = data["next_action"]

action = self._get_actions(state)
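# The dataset actions loaded above are replaced with greedy actions decoded by the current model.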

q_pred_all_actions = self._learn_model.forward(state, action=action)[:, 1:, :]
# torch.Size([2048, 6, 256])
@@ -377,7 +380,7 @@ def batch_select_indices(t, indices):
losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action)
td_loss = losses_all_actions_but_last + losses_last_action
td_loss = td_loss.mean()
loss = td_loss + conservative_loss
loss = td_loss + conservative_loss * 0
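# NOTE: the conservative term is multiplied by 0, so it currently contributes nothing to the loss.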
self._optimizer_q.zero_grad()
loss.backward()
self._optimizer_q.step()
@@ -411,14 +414,48 @@ def batch_select_indices(t, indices):
"q_real": q_pred.mean().item(),
},
)
return loss, q_pred_all_actions.mean().item()
return {
"td_error": loss.item(),
"policy_loss": q_pred_all_actions.mean().item(),
}

def _get_actions(self, obs):
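# Decode the action autoregressively, one dimension at a time: the Q-values
# for dimension i are conditioned on the bin indices already chosen for
# dimensions 0..i-1.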

action = self._eval_model.get_actions(obs)
action = 2.0 * action / (1.0 * self._action_bin) - 1.0
action_bins = None
action_bins = torch.full(
(obs.size(0), self._action_dim), -1, dtype=torch.long, device=obs.device
)
for action_idx in range(self._action_dim):
if action_idx == 0:
q_values = self._eval_model.forward(obs)
else:
q_values = self._eval_model.forward(
obs, action=action_bins[:, :action_idx]
)[:, action_idx-1:action_idx, :]
selected_action_bins = q_values.argmax(dim=-1)
action_bins[:, action_idx] = selected_action_bins.squeeze()
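# Map the selected bin indices back to continuous actions in [-1, 1].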
action = 2.0 * action_bins.float() / (1.0 * self._action_bin) - 1.0
return action

def _monitor_vars_learn(self) -> List[str]:
"""
Overview:
Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
as text logger, tensorboard logger, will use these keys to save the corresponding data.
Returns:
- necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
"""
return [
"value_loss" "alpha_loss",
"policy_loss",
"critic_loss",
"cur_lr_q",
"cur_lr_p",
"target_q_value",
"alpha",
"td_error",
"transformed_log_prob",
]

# def _monitor_vars_learn(self) -> List[str]:
# """
# Overview:
22 changes: 8 additions & 14 deletions qtransformer/algorithm/serial_entry_qtransformer.py
@@ -109,13 +109,7 @@ def serial_pipeline_offline(
wandb.init(**cfg.wandb)
config = merge_two_dicts_into_newone(EasyDict(wandb.config), cfg)
wandb.config.update(config)

if get_rank() == 0:
tb_logger = SummaryWriter(
os.path.join("./{}/log/".format(cfg.exp_name), "serial")
)
else:
tb_logger = None
tb_logger = SummaryWriter(os.path.join("./{}/log/".format(cfg.exp_name), "serial"))
learner = BaseLearner(
cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name
)
@@ -139,14 +133,14 @@ def serial_pipeline_offline(
for train_data in dataloader:
learner.train(train_data)

if evaluator.should_eval(learner.train_iter):
stop, eval_info = evaluator.eval(
learner.save_checkpoint, learner.train_iter
)
# if evaluator.should_eval(learner.train_iter):
# stop, eval_info = evaluator.eval(
# learner.save_checkpoint, learner.train_iter
# )

if stop or learner.train_iter >= max_train_iter:
stop = True
break
# if stop or learner.train_iter >= max_train_iter:
# stop = True
# break

learner.call_hook("after_run")
if get_rank() == 0:
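Note: the following is a minimal, self-contained sketch of the autoregressive action decoding that the new _get_actions implements, included only to make the diff easier to read in isolation. DummyQModel and decode_actions are illustrative stand-ins and are not part of this repository; the dummy model merely mimics the forward(obs, action=...) call pattern, returning Q-values of shape (batch, dims, bins).

import torch


class DummyQModel:
    # Illustrative stand-in for the policy's eval model (an assumption, not the real class).

    def __init__(self, action_dim: int, action_bin: int):
        self.action_dim = action_dim
        self.action_bin = action_bin

    def forward(self, obs: torch.Tensor, action: torch.Tensor = None) -> torch.Tensor:
        # Return random Q-values of shape (batch, num_dims, num_bins); a real model
        # would condition on `obs` and on the partial `action` prefix.
        num_dims = self.action_dim if action is None else action.shape[1] + 1
        return torch.randn(obs.size(0), num_dims, self.action_bin)


def decode_actions(model: DummyQModel, obs: torch.Tensor) -> torch.Tensor:
    # Greedily pick one discrete bin per action dimension, conditioning each
    # dimension on the bins already chosen, then map the bins back to [-1, 1].
    batch, action_dim, action_bin = obs.size(0), model.action_dim, model.action_bin
    action_bins = torch.full((batch, action_dim), -1, dtype=torch.long, device=obs.device)
    for idx in range(action_dim):
        if idx == 0:
            q = model.forward(obs)[:, 0, :]
        else:
            q = model.forward(obs, action=action_bins[:, :idx])[:, idx - 1, :]
        action_bins[:, idx] = q.argmax(dim=-1)
    return 2.0 * action_bins.float() / action_bin - 1.0


if __name__ == "__main__":
    model = DummyQModel(action_dim=6, action_bin=256)
    obs = torch.randn(4, 17)
    print(decode_actions(model, obs).shape)  # expected: torch.Size([4, 6])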
