Skip to content

Commit

Permalink
add checkpoint path flag to train classifier script
Browse files Browse the repository at this point in the history
  • Loading branch information
Leo428 committed Jan 28, 2024
1 parent 037d6a3 commit 56473c0
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \
python train_reward_classifier.py "$@" \
--classifier_name bw \
--classifier_ckpt_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bw_classifier_ckpt \
--positive_demo_paths ./classifier_data/bw_bin_relocate_338_front_cam_goal_2024-01-23_15-06-18.pkl \
--positive_demo_paths ./classifier_data/bw_bin_relocate_400_front_cam_goal_2024-01-23_15-12-49.pkl \
--negative_demo_paths ./classifier_data/bw_bin_relocate_475_front_cam_failed_2024-01-23_15-12-49.pkl \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \
python train_reward_classifier.py "$@" \
--classifier_name fw \
--classifier_ckpt_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/fw_classifier_ckpt \
--positive_demo_paths ./classifier_data/fw_bin_relocate_400_front_cam_goal_2024-01-23_15-06-18.pkl \
--positive_demo_paths ./classifier_data/fw_bin_relocate_400_front_cam_goal_2024-01-23_15-12-49.pkl \
--negative_demo_paths ./classifier_data/fw_bin_relocate_486_front_cam_failed_2024-01-23_15-12-49.pkl \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
FLAGS = flags.FLAGS
flags.DEFINE_multi_string("positive_demo_paths", None, "paths to positive demos")
flags.DEFINE_multi_string("negative_demo_paths", None, "paths to negative demos")
flags.DEFINE_string("classifier_name", "fw", "Name of classifier: fw or bw")
flags.DEFINE_string("classifier_ckpt_path", None, "Path to classifier checkpoint")


def main(_):
Expand Down Expand Up @@ -148,7 +148,7 @@ def loss_fn(params):
)

checkpoints.save_checkpoint(
f"/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/{FLAGS.classifier_name}_classifier_ckpt",
FLAGS.classifier_ckpt_path,
classifier,
step=num_epochs,
overwrite=True,
Expand Down

0 comments on commit 56473c0

Please sign in to comment.