Skip to content

Commit

Permalink
[BugFix] Fix failing tests
Browse files Browse the repository at this point in the history
ghstack-source-id: cacee5ae909f4b6109dd887d240a8956acc1bd38
Pull Request resolved: #2582
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent 408cf7d commit 4d71592
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ jobs:
REF_TYPE=${{ github.ref_type }}
REF_NAME=${{ github.ref_name }}
apt-get update
apt-get install rsync -y
if [[ "${REF_TYPE}" == branch ]]; then
if [[ "${REF_NAME}" == main ]]; then
Expand Down
13 changes: 6 additions & 7 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,10 @@ def step(self, frames: int = 1) -> None:
for _ in range(frames):
self.sigma.data.copy_(
torch.maximum(
self.sigma_end(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
),
)
self.sigma_end,
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps,
),
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -946,8 +945,8 @@ def _make_noise_pair(
noise = tensordict.get(self.noise_key).clone()
steps = tensordict.get(self.steps_key).clone()
if is_init is not None:
noise = torch.masked_fill(noise, is_init, 0)
steps = torch.masked_fill(steps, is_init, 0)
noise = torch.masked_fill(noise, expand_right(is_init, noise.shape), 0)
steps = torch.masked_fill(steps, expand_right(is_init, steps.shape), 0)
return noise, steps

def add_sample(
Expand Down

0 comments on commit 4d71592

Please sign in to comment.