Skip to content

Commit

Permalink
fix sampler pytorch/rl#1762
Browse files Browse the repository at this point in the history
  • Loading branch information
nicklashansen committed Dec 25, 2023
1 parent ca4dfa1 commit 2f86a1e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 6 additions & 2 deletions tdmpc2/common/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ def _get_stop_and_length(self, storage, fallback=True):
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
return self._cache.setdefault("stop-and-length", vals)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
Expand All @@ -257,7 +259,9 @@ def _get_stop_and_length(self, storage, fallback=True):
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
return self._cache.setdefault("stop-and-length", vals)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
Expand Down
3 changes: 1 addition & 2 deletions tdmpc2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def train(cfg: dict):
cfg=cfg,
env=make_env(cfg),
agent=TDMPC2(cfg),
buffer=CropBuffer(cfg),
# buffer=SliceBuffer(cfg),
buffer=SliceBuffer(cfg),
logger=Logger(cfg),
)
trainer.train()
Expand Down

0 comments on commit 2f86a1e

Please sign in to comment.