Skip to content

Commit

Permalink
[BugFix] Fix gym benchmark (#1619)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent 90ad21c commit 502a2e6
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,14 @@

if __name__ == "__main__":
for envname in [
"HalfCheetah-v4",
"CartPole-v1",
"HalfCheetah-v4",
"myoHandReachRandom-v0",
"ALE/Breakout-v5",
"CartPole-v1",
]:
# the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes
for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)):
with open(
f"atari_{envname}_{num_workers}.txt".replace("/", "-"), "w+"
) as log:
for num_workers, num_collectors in zip((32, 64, 8, 16), (8, 8, 2, 4)):
with open(f"{envname}_{num_workers}.txt".replace("/", "-"), "w+") as log:
if "myo" in envname:
gym_backend = "gym"
else:
Expand Down Expand Up @@ -219,7 +216,7 @@ def make_env(

penv = EnvCreator(
lambda num_workers=num_workers // num_collectors: make_env(
num_workers
num_workers=num_workers
)
)
collector = MultiaSyncDataCollector(
Expand Down Expand Up @@ -306,7 +303,7 @@ def make_env(

penv = EnvCreator(
lambda num_workers=num_workers // num_collectors: make_env(
num_workers
num_workers=num_workers
)
)
collector = MultiSyncDataCollector(
Expand Down

0 comments on commit 502a2e6

Please sign in to comment.