feature(xrk): add q-transformer #783
base: main
Conversation
ding/policy/__init__.py
Outdated
@@ -19,6 +19,7 @@
from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .qtransformer import QtransformerPolicy
Rename to QTransformerPolicy (capital T), consistent with the naming of the other policy classes.
from ding.entry import serial_pipeline_offline
from ding.config import read_config
from pathlib import Path
from ding.model.template.qtransformer import QTransformer
Import from the secondary directory instead, such as:
from ding.model import QTransformer
alpha=0.2,
discount_factor_gamma=0.9,
min_reward=0.1,
auto_alpha=False,
remove unused fields like this
ding/policy/qtransformer.py
Outdated
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
self._low = np.array(self._cfg.other["low"])
We don't need low and high here; the policy always assumes the action value range is [-1, 1].
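If the environment's native action range differs, the rescaling can live outside the policy, e.g. in an env wrapper. A minimal sketch of the two affine maps (function names are hypothetical, not part of DI-engine):

```python
import numpy as np

def to_policy_range(action, low, high):
    """Affine map from the env's [low, high] range to the policy's [-1, 1]."""
    low, high = np.asarray(low), np.asarray(high)
    return 2.0 * (action - low) / (high - low) - 1.0

def to_env_range(action, low, high):
    """Inverse map from the policy's [-1, 1] back to the env's [low, high]."""
    low, high = np.asarray(low), np.asarray(high)
    return low + (action + 1.0) * (high - low) / 2.0
```

With this split, `low`/`high` stay an env-side concern and the policy code never needs them.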
cuda=True,
model=dict(
    num_actions=3,
    action_bins=256,
This action_bins field is not used in the policy.
ding/policy/qtransformer.py
Outdated
selected = t.gather(-1, indices)
return rearrange(selected, '... 1 -> ...')

def _discretize_action(self, actions):
we can optimize this for loop:
action_values = np.linspace(-1, 1, 8)[np.newaxis, ...].repeat(4, 0)
action_values = torch.as_tensor(action_values).to(self._device)
diff = (actions.unsqueeze(-1) - action_values.unsqueeze(0)) ** 2
indices = diff.argmin(-1)
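A self-contained NumPy version of the same nearest-bin lookup, where the `4` and `8` above correspond to `num_actions` and `action_bins` (the function name here is hypothetical):

```python
import numpy as np

def discretize_actions(actions, num_actions=4, action_bins=8):
    """Map continuous actions in [-1, 1] to nearest-bin indices without a Python loop."""
    # (num_actions, action_bins) grid of bin centers, shared across the batch
    action_values = np.linspace(-1, 1, action_bins)[np.newaxis, ...].repeat(num_actions, 0)
    # squared distance of every action to every bin center: (batch, num_actions, action_bins)
    diff = (actions[..., np.newaxis] - action_values[np.newaxis, ...]) ** 2
    return diff.argmin(-1)
```

The broadcasted distance tensor replaces the per-action loop, so the whole batch is discretized in one vectorized pass.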
ding/policy/qtransformer.py
Outdated
actions = data['action']

# get q
num_timesteps, device = states.shape[1], states.device
Use self._device, which is a default member variable of Policy.
ding/policy/qtransformer.py
Outdated
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from ema_pytorch import EMA
remove unused third party libraries
ding/policy/qtransformer.py
Outdated
from pathlib import Path
from functools import partial
from contextlib import nullcontext
polish imports
ding/policy/qtransformer.py
Outdated
from torchtyping import TensorType

from einops import rearrange, repeat, pack, unpack
Add einops in setup.py.
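A sketch of the corresponding change, assuming einops belongs in the main install_requires list of setup.py (the exact location and any version pin are up to the maintainers):

```python
# setup.py (fragment, sketch)
install_requires=[
    # ... existing dependencies ...
    'einops',
],
```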
ding/policy/qtransformer.py
Outdated
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
We will not use beartype to validate runtime types in the current version, so remove it in this PR.
ding/model/template/qtransformer.py
Outdated
@@ -0,0 +1,753 @@
from random import random
from functools import partial, cache
cache is a new feature in Python 3.9; for compatibility, implement it as follows:
try:
    from functools import cache  # only in Python >= 3.9
except ImportError:
    from functools import lru_cache
    cache = lru_cache(maxsize=None)
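The compatibility shim can be exercised end to end; the memoized `fib` below is a hypothetical example, not part of the PR:

```python
try:
    from functools import cache  # only in Python >= 3.9
except ImportError:
    from functools import lru_cache
    cache = lru_cache(maxsize=None)

calls = []

@cache
def fib(n):
    """Naive recursion made linear by memoization."""
    calls.append(n)
    return n if n < 2 else fib(n - 1) + fib(n - 2)
```

On either Python version the decorator behaves the same: `fib(10)` computes each subproblem exactly once.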
Description
Related Issue
TODO
Check List