Skip to content

Commit

Permalink
fix(zjow): fix complex obs demo for ppo pipeline (#786)
Browse files Browse the repository at this point in the history
* fix complex obs demo for ppo pipeline

* polish code

* polish code

* prevent wandb version error

* polish code
  • Loading branch information
zjowowen authored Apr 7, 2024
1 parent 6c36145 commit 7d05491
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 4 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make unittest
- name: Upload coverage to Codecov
Expand Down Expand Up @@ -55,5 +60,10 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make benchmark
7 changes: 6 additions & 1 deletion ding/example/ppo_with_complex_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
cuda=True,
action_space='discrete',
model=dict(
obs_shape=None,
obs_shape=dict(
key_0=dict(k1=(), k2=()),
key_1=(5, 10),
key_2=(10, 10, 3),
key_3=(2, ),
),
action_shape=2,
action_space='discrete',
critic_head_hidden_size=138,
Expand Down
5 changes: 2 additions & 3 deletions ding/framework/middleware/functional/advantage_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ding.rl_utils import gae, gae_data, get_train_sample
from ding.framework import task
from ding.utils.data import ttorch_collate
from ding.utils.dict_helper import convert_easy_dict_to_dict
from ding.torch_utils import to_device

if TYPE_CHECKING:
Expand All @@ -33,10 +34,8 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
# Unify the shape of obs and action
obs_shape = cfg['policy']['model']['obs_shape']
obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
action_shape = cfg['policy']['model']['action_shape']
action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \
else torch.Size(torch.tensor(action_shape).unsqueeze(0))

def _gae(ctx: "OnlineRLContext"):
"""
Expand Down
1 change: 1 addition & 0 deletions ding/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \
RunningMeanStd, make_key_as_identifier, remove_illegal_item
from .design_helper import SingletonMetaclass
from .dict_helper import convert_easy_dict_to_dict
from .file_helper import read_file, save_file, remove_file
from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \
try_import_rediscluster
Expand Down
13 changes: 13 additions & 0 deletions ding/utils/dict_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from easydict import EasyDict


def convert_easy_dict_to_dict(easy_dict: EasyDict) -> dict:
    """
    Overview:
        Convert an EasyDict object to a dict object recursively. Nested EasyDict values are converted \
        wherever they appear, including inside plain ``list`` and ``tuple`` values.
    Arguments:
        - easy_dict (:obj:`EasyDict`): The EasyDict object to be converted. Any mapping exposing ``items()`` \
            is accepted at the top level.
    Returns:
        - dict: A new plain dict. Values that are neither EasyDict nor a plain list/tuple are kept as-is \
            (same object); plain lists/tuples are rebuilt as new containers of the same type.
    """

    def _to_plain(value):
        # Nested EasyDict -> plain dict, converting its values in turn.
        if isinstance(value, EasyDict):
            return {key: _to_plain(item) for key, item in value.items()}
        # EasyDict wraps dicts stored inside list/tuple values as EasyDict too, so those containers
        # must be walked as well. Only exact list/tuple types are rebuilt; subclasses such as
        # namedtuples are left untouched because they cannot be rebuilt from a single iterable.
        if type(value) in (list, tuple):
            return type(value)(_to_plain(item) for item in value)
        return value

    return {key: _to_plain(item) for key, item in easy_dict.items()}

0 comments on commit 7d05491

Please sign in to comment.