Commit

fix demo exm (#57)
Signed-off-by: youliangtan <[email protected]>
youliangtan authored Jun 7, 2024
1 parent 7d000b2 commit 20e27b4
Showing 6 changed files with 63 additions and 90 deletions.
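All six example scripts receive the same refactor: the wandb_logger parameter is dropped from learner(), the logger is now constructed inside the learner process itself, and the create_replay_buffer_and_wandb_logger() helper in main() is inlined into the "if FLAGS.learner:" branch. A minimal sketch of the post-commit learner prologue, assuming the make_wandb_logger helper and FLAGS object these scripts already import (buffer arguments and the training loop are elided):

    def learner(rng, agent, replay_buffer, demo_buffer):
        """The learner loop, which runs when "--learner" is set to True."""
        # The learner now owns its logger instead of receiving one from main().
        wandb_logger = make_wandb_logger(
            project="serl_dev",
            description=FLAGS.exp_name or FLAGS.env,
            debug=FLAGS.debug,
        )
        update_steps = 0  # to track the step in the training loop
        ...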
examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py (10 additions & 14 deletions)
@@ -293,10 +293,17 @@ def update_params_bw(params):
 ##############################################################################


-def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer, wandb_logger=None):
+def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0
     env_steps = 0
@@ -496,24 +503,14 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
             capacity=FLAGS.replay_buffer_capacity,
             image_keys=image_keys,
         )
-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev_fwbw",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
         demo_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
@@ -535,7 +532,6 @@ def create_replay_buffer_and_wandb_logger():
             agent,
             replay_buffer,
             demo_buffer=demo_buffer,
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
examples/async_cable_route_drq/async_drq_randomized.py (10 additions & 14 deletions)
@@ -220,10 +220,17 @@ def update_params(params):
 ##############################################################################


-def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer, wandb_logger=None):
+def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0

@@ -366,24 +373,14 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
             capacity=FLAGS.replay_buffer_capacity,
             image_keys=image_keys,
         )
-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
         demo_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
@@ -409,7 +406,6 @@ def create_replay_buffer_and_wandb_logger():
             agent,
             replay_buffer,
             demo_buffer=demo_buffer,
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
examples/async_drq_sim/async_drq_sim.py (13 additions & 19 deletions)
@@ -185,11 +185,17 @@ def learner(
     agent: DrQAgent,
     replay_buffer: MemoryEfficientReplayBufferDataStore,
     demo_buffer: Optional[MemoryEfficientReplayBufferDataStore] = None,
-    wandb_logger=None,
 ):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0

@@ -333,7 +339,8 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = make_replay_buffer(
             env,
             capacity=FLAGS.replay_buffer_capacity,
@@ -342,18 +349,6 @@ def create_replay_buffer_and_wandb_logger():
             image_keys=image_keys,
         )

-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
-
         print_green("replay buffer created")
         print_green(f"replay_buffer size: {len(replay_buffer)}")

@@ -368,10 +363,10 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]:
             # This default does nothing
             return data

-        demo_buffer = MemoryEfficientReplayBufferDataStore(
-            env.observation_space,
-            env.action_space,
-            capacity=10000,
+        demo_buffer = make_replay_buffer(
+            env,
+            capacity=FLAGS.replay_buffer_capacity,
+            type="memory_efficient_replay_buffer",
             image_keys=image_keys,
             preload_rlds_path=FLAGS.preload_rlds_path,
             preload_data_transform=preload_data_transform,
@@ -398,7 +393,6 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]:
             agent,
             replay_buffer,
             demo_buffer=demo_buffer,  # None if no demo data is provided
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
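Besides the shared wandb refactor, this file also replaces the hand-built demo buffer (a MemoryEfficientReplayBufferDataStore with a hard-coded capacity=10000) with the make_replay_buffer factory, so the demo buffer now honors FLAGS.replay_buffer_capacity and can preload demonstrations from RLDS. The post-commit call, copied from the hunk above (per the surrounding code, FLAGS.preload_rlds_path may be unset, in which case nothing is preloaded):

    demo_buffer = make_replay_buffer(
        env,
        capacity=FLAGS.replay_buffer_capacity,
        type="memory_efficient_replay_buffer",
        image_keys=image_keys,
        preload_rlds_path=FLAGS.preload_rlds_path,
        preload_data_transform=preload_data_transform,
    )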
examples/async_pcb_insert_drq/async_drq_randomized.py (10 additions & 14 deletions)
@@ -275,10 +275,17 @@ def update_params(params):
 ##############################################################################


-def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer, wandb_logger=None):
+def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0
     global PAUSE_EVENT_FLAG
@@ -436,24 +443,14 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
             capacity=FLAGS.replay_buffer_capacity,
             image_keys=image_keys,
         )
-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
         demo_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
@@ -475,7 +472,6 @@ def create_replay_buffer_and_wandb_logger():
             agent,
             replay_buffer,
             demo_buffer=demo_buffer,
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
examples/async_peg_insert_drq/async_drq_randomized.py (10 additions & 14 deletions)
@@ -211,10 +211,17 @@ def update_params(params):
 ##############################################################################


-def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer, wandb_logger=None):
+def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0

@@ -344,24 +351,14 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
             capacity=FLAGS.replay_buffer_capacity,
             image_keys=image_keys,
         )
-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
         demo_buffer = MemoryEfficientReplayBufferDataStore(
             env.observation_space,
             env.action_space,
@@ -383,7 +380,6 @@ def create_replay_buffer_and_wandb_logger():
             agent,
             replay_buffer,
             demo_buffer=demo_buffer,
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
examples/async_sac_state_sim/async_sac_state_sim.py (10 additions & 15 deletions)
@@ -168,10 +168,17 @@ def update_params(params):
 ##############################################################################


-def learner(rng, agent: SACAgent, replay_buffer, replay_iterator, wandb_logger=None):
+def learner(rng, agent: SACAgent, replay_buffer, replay_iterator):
     """
     The learner loop, which runs when "--learner" is set to True.
     """
+    # set up wandb and logging
+    wandb_logger = make_wandb_logger(
+        project="serl_dev",
+        description=FLAGS.exp_name or FLAGS.env,
+        debug=FLAGS.debug,
+    )
+
     # To track the step in the training loop
     update_steps = 0

@@ -266,26 +273,15 @@ def main(_):
         jax.tree_map(jnp.array, agent), sharding.replicate()
     )

-    def create_replay_buffer_and_wandb_logger():
+    if FLAGS.learner:
+        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
         replay_buffer = make_replay_buffer(
             env,
             capacity=FLAGS.replay_buffer_capacity,
             rlds_logger_path=FLAGS.log_rlds_path,
             type="replay_buffer",
             preload_rlds_path=FLAGS.preload_rlds_path,
         )
-
-        # set up wandb and logging
-        wandb_logger = make_wandb_logger(
-            project="serl_dev",
-            description=FLAGS.exp_name or FLAGS.env,
-            debug=FLAGS.debug,
-        )
-        return replay_buffer, wandb_logger
-
-    if FLAGS.learner:
-        sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
-        replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger()
         replay_iterator = replay_buffer.get_iterator(
             sample_args={
                 "batch_size": FLAGS.batch_size * FLAGS.utd_ratio,
@@ -299,7 +295,6 @@ def create_replay_buffer_and_wandb_logger():
             agent,
             replay_buffer,
             replay_iterator=replay_iterator,
-            wandb_logger=wandb_logger,
         )

     elif FLAGS.actor:
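The state-based SAC example differs from the DrQ ones in its data path: learner() takes a replay_iterator instead of a demo_buffer, and the buffer comes from make_replay_buffer with type="replay_buffer" plus optional RLDS logging. A sketch of the post-commit learner-side wiring, assembled from the hunks above; the remaining get_iterator arguments are elided in the diff, so only the visible ones are shown:

    replay_iterator = replay_buffer.get_iterator(
        sample_args={
            "batch_size": FLAGS.batch_size * FLAGS.utd_ratio,
        },
    )
    learner(
        sampling_rng,
        agent,
        replay_buffer,
        replay_iterator=replay_iterator,
    )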
