Skip to content

Commit

Permalink
[BugFix] Fix failing tests
Browse files Browse the repository at this point in the history
ghstack-source-id: 1c34b22fa4d7253ad318bd061da14a671be9a4d6
Pull Request resolved: #2582
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent 408cf7d commit e6b14e5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ jobs:
REF_TYPE=${{ github.ref_type }}
REF_NAME=${{ github.ref_name }}
apt-get update
apt-get install rsync -y
if [[ "${REF_TYPE}" == branch ]]; then
if [[ "${REF_NAME}" == main ]]; then
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,7 +2076,10 @@ def test_transform_rb(self, rbclass):
):
td = rb.sample(10)

@retry(AssertionError, tries=10, delay=0)
def test_collector_match(self):
torch.manual_seed(0)

# The counter in the collector should match the one from the transform
t = TrajCounter()

Expand Down
13 changes: 6 additions & 7 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,10 @@ def step(self, frames: int = 1) -> None:
for _ in range(frames):
self.sigma.data.copy_(
torch.maximum(
self.sigma_end(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
),
)
self.sigma_end,
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps,
),
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -946,8 +945,8 @@ def _make_noise_pair(
noise = tensordict.get(self.noise_key).clone()
steps = tensordict.get(self.steps_key).clone()
if is_init is not None:
noise = torch.masked_fill(noise, is_init, 0)
steps = torch.masked_fill(steps, is_init, 0)
noise = torch.masked_fill(noise, expand_right(is_init, noise.shape), 0)
steps = torch.masked_fill(steps, expand_right(is_init, steps.shape), 0)
return noise, steps

def add_sample(
Expand Down

0 comments on commit e6b14e5

Please sign in to comment.