Skip to content

Commit

Permalink
[BugFix] fix dreamer actor (#1697)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <[email protected]>
  • Loading branch information
FrankTianTT and vmoens authored Nov 15, 2023
1 parent 02ff00d commit e1eb69d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchrl/trainers/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
out_keys=[action_key],
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(**{action_key: proof_environment.action_spec}),
),
)
Expand Down Expand Up @@ -703,8 +704,9 @@ def _dreamer_make_actor_real(
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys=[action_key],
default_interaction_type=InteractionType.RANDOM,
default_interaction_type=InteractionType.MODE,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(
**{action_key: proof_environment.action_spec.to("cpu")}
),
Expand Down

2 comments on commit e1eb69d

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: e1eb69d Previous: 02ff00d Ratio
benchmarks/test_objectives_benchmarks.py::test_values[td0_return_estimate-False-False] 1951.445613324819 iter/sec (stddev: 0.00024149821207440126) 4986.895933579403 iter/sec (stddev: 0.00002776360724696224) 2.56

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: e1eb69d Previous: 02ff00d Ratio
benchmarks/test_objectives_benchmarks.py::test_values[td0_return_estimate-False-False] 1768.8045771065874 iter/sec (stddev: 0.00007214817209200195) 5604.701941794002 iter/sec (stddev: 0.000012523831818547346) 3.17
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 219.46478223047686 iter/sec (stddev: 0.0009249430285535704) 577.7983464651579 iter/sec (stddev: 0.000025674620522692584) 2.63

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.