Commit

[Algorithm] Update discrete SAC example (#1745)
Co-authored-by: vmoens <[email protected]>
BY571 and vmoens authored Jan 9, 2024
1 parent eb603ab commit 6c68f7e
Showing 5 changed files with 493 additions and 335 deletions.
28 changes: 14 additions & 14 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -148,6 +148,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
 env.name=Pendulum-v1 \
 network.device=cuda:0 \
 logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/discrete_sac.py \
+collector.total_frames=48 \
+collector.init_random_frames=10 \
+collector.frames_per_batch=16 \
+collector.env_per_collector=1 \
+collector.device=cuda:0 \
+optim.batch_size=10 \
+optim.utd_ratio=1 \
+network.device=cuda:0 \
+optim.batch_size=10 \
+optim.utd_ratio=1 \
+replay_buffer.size=120 \
+env.name=CartPole-v1 \
+logger.backend=
 # logger.record_video=True \
 # logger.record_frames=4 \
 python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
@@ -246,20 +260,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
 logger.record_frames=4 \
 buffer.size=120 \
 logger.backend=
-python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
-collector.total_frames=48 \
-collector.init_random_frames=10 \
-collector.frames_per_batch=16 \
-collector.env_per_collector=1 \
-collector.device=cuda:0 \
-optim.batch_size=10 \
-optim.utd_ratio=1 \
-network.device=cuda:0 \
-optim.batch_size=10 \
-optim.utd_ratio=1 \
-replay_buffer.size=120 \
-env.name=Pendulum-v1 \
-logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
 collector.total_frames=48 \
 optim.batch_size=10 \
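
The `discrete_sac.py` test command above passes its settings as dotted `key=value` overrides. As a rough illustration of how such overrides land in a nested config, the sketch below folds them into a plain dictionary. It is a simplified stand-in for Hydra-style override handling, not the library's implementation: `apply_overrides` is a hypothetical helper, only integers are parsed (everything else stays a string), and the batch count assumes the collector yields `total_frames / frames_per_batch` batches.

```python
def apply_overrides(config, overrides):
    """Fold dotted ``key=value`` strings into a nested dict (hypothetical helper)."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections such as "collector" on demand.
            node = node.setdefault(name, {})
        # Keep it simple: integers are parsed, everything else stays a string.
        node[leaf] = int(raw_value) if raw_value.isdigit() else raw_value
    return config


# Overrides copied from the discrete SAC test command (duplicates omitted).
overrides = [
    "collector.total_frames=48",
    "collector.init_random_frames=10",
    "collector.frames_per_batch=16",
    "collector.env_per_collector=1",
    "collector.device=cuda:0",
    "optim.batch_size=10",
    "optim.utd_ratio=1",
    "network.device=cuda:0",
    "replay_buffer.size=120",
    "env.name=CartPole-v1",
    "logger.backend=",
]

cfg = apply_overrides({}, overrides)

# Assumption: one batch per frames_per_batch frames until total_frames is reached.
num_batches = cfg["collector"]["total_frames"] // cfg["collector"]["frames_per_batch"]
print(cfg["env"]["name"], num_batches)
```

Under these assumptions the smoke test collects only three small batches, and the empty `logger.backend=` override leaves the backend as an empty string.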
79 changes: 47 additions & 32 deletions examples/discrete_sac/config.yaml
@@ -1,37 +1,52 @@
-# Logger
-logger: wandb
-exp_name: discrete_sac
-record_interval: 1
-mode: online
 
-# Environment
-env_name: CartPole-v1
-frame_skip: 1
-from_pixels: false
-reward_scaling: 1.0
-init_env_steps: 1000
-seed: 42
+# task and env
+env:
+  name: CartPole-v1
+  task: ""
+  exp_name: ${env.name}_DiscreteSAC
+  library: gym
+  seed: 42
+  max_episode_steps: 500
 
-# Collector
-env_per_collector: 1
-max_frames_per_traj: 500
-total_frames: 1000000
-init_random_frames: 5000
-frames_per_batch: 500 # 500 * env_per_collector
+# collector
+collector:
+  total_frames: 25000
+  init_random_frames: 1000
+  init_env_steps: 1000
+  frames_per_batch: 500
+  reset_at_each_iter: False
+  device: cuda:0
+  env_per_collector: 1
+  num_workers: 1
 
-# Replay Buffer
-prb: 0
-buffer_size: 1000000
+# replay buffer
+replay_buffer:
+  prb: 0 # use prioritized experience replay
+  size: 1000000
+  scratch_dir: ${env.exp_name}_${env.seed}
 
-# Optimization
-utd_ratio: 1.0
-gamma: 0.99
-batch_size: 256
-lr: 3.0e-4
-weight_decay: 0.0
-target_update_polyak: 0.995
-target_entropy_weight: 0.2
-# default is 0.98 but needs to be decreased for env
-# with small action space
+# optim
+optim:
+  utd_ratio: 1.0
+  gamma: 0.99
+  batch_size: 256
+  lr: 3.0e-4
+  weight_decay: 0.0
+  target_update_polyak: 0.995
+  target_entropy_weight: 0.2
+  target_entropy: "auto"
+  loss_function: l2
+  # default is 0.98 but needs to be decreased for env
+  # with small action space
 
-device: cpu
+# network
+network:
+  hidden_sizes: [256, 256]
+  activation: relu
+  device: "cuda:0"
+
+# logging
+logger:
+  backend: wandb
+  mode: online
+  eval_iter: 5000
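
The new config derives `env.exp_name` and `replay_buffer.scratch_dir` from other keys through `${...}` references. Assuming the example loads this file through Hydra/OmegaConf (as the dotted overrides in the test script suggest), those references would be expanded when the values are read. The sketch below mimics that expansion with a small regex-based resolver in plain Python. `lookup` and `resolve` are hypothetical helpers, only the keys needed for the illustration are included, and there is no cycle detection.

```python
import re

# Subset of the new config.yaml, written as a nested dict for illustration.
config = {
    "env": {
        "name": "CartPole-v1",
        "exp_name": "${env.name}_DiscreteSAC",
        "seed": 42,
    },
    "replay_buffer": {"scratch_dir": "${env.exp_name}_${env.seed}"},
}

_REFERENCE = re.compile(r"\$\{([^}]+)\}")


def lookup(root, dotted_key):
    """Walk a dotted key such as 'env.name' through nested dicts."""
    node = root
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(root, value):
    """Recursively expand ${a.b} references inside a string value."""
    if not isinstance(value, str):
        return value
    return _REFERENCE.sub(
        lambda match: str(resolve(root, lookup(root, match.group(1)))), value
    )


exp_name = resolve(config, config["env"]["exp_name"])
scratch_dir = resolve(config, config["replay_buffer"]["scratch_dir"])
print(exp_name)     # CartPole-v1_DiscreteSAC
print(scratch_dir)  # CartPole-v1_DiscreteSAC_42
```

Because `exp_name` is built from `env.name`, overriding `env.name` on the command line would also change the experiment name and scratch directory under this reading.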