Skip to content

Commit

Permalink
bug fix for RM
Browse files: browse the repository at this point in the history
  • Loading branch information
davidzhu27 committed Sep 18, 2024
1 parent 03549f1 commit 21590e3
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 77 deletions.
26 changes: 13 additions & 13 deletions algorithms/offline/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pbrl import scale_rewards, generate_pbrl_dataset, make_latent_reward_dataset, train_latent, predict_and_label_latent_reward
from pbrl import label_by_trajectory_reward, generate_pbrl_dataset_no_overlap, small_d4rl_dataset
from pbrl import label_by_trajectory_reward_multiple_bernoullis, label_by_original_rewards
from pbrl import label_by_original_rewards, pick_and_generate_pbrl_dataset
from ipl_helper import save_preference_dataset

TensorBatch = List[torch.Tensor]
Expand All @@ -39,6 +39,8 @@ class TrainConfig:
out_name: str = ""
quick_stop: int = 0
dataset_size_multiplier: float = 1.0
reuse_fraction: float = 0.0
reuse_times:int = 0

# Experiment
device: str = "cuda"
Expand Down Expand Up @@ -909,26 +911,24 @@ def train(config: TrainConfig):
num_t = config.num_t
len_t = config.len_t
num_trials = config.num_berno
allow_overlap=config.bin_label_allow_overlap
reuse_fraction = config.reuse_fraction
reuse_times = config.reuse_times

if config.latent_reward:
dataset = scale_rewards(dataset)
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
latent_reward_model, indices = train_latent(dataset, pbrl_dataset, num_berno=num_trials, num_t=num_t, len_t=len_t)
dataset = predict_and_label_latent_reward(dataset, latent_reward_model, indices)
elif config.bin_label:
dataset = scale_rewards(dataset)
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
dataset = label_by_trajectory_reward(dataset, pbrl_dataset, num_t=num_t, len_t=len_t, num_trials=num_trials)
else:
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
dataset = label_by_original_rewards(dataset, pbrl_dataset, num_t)
dataset = small_d4rl_dataset(dataset, dataset_size_multiplier=config.dataset_size_multiplier)
print(f'Dataset size: {(dataset["observations"]).shape[0]}')
Expand Down
28 changes: 14 additions & 14 deletions algorithms/offline/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from pbrl import scale_rewards, generate_pbrl_dataset, make_latent_reward_dataset, train_latent, predict_and_label_latent_reward
from pbrl import label_by_trajectory_reward, generate_pbrl_dataset_no_overlap, small_d4rl_dataset
from pbrl import label_by_trajectory_reward_multiple_bernoullis, label_by_original_rewards
from ipl_helper import save_preference_dataset
from pbrl import label_by_original_rewards, pick_and_generate_pbrl_dataset
from ipl_helper import save_preference_dataset

TensorBatch = List[torch.Tensor]

Expand All @@ -45,6 +45,8 @@ class TrainConfig:
out_name: str = ""
quick_stop: int = 0
dataset_size_multiplier: float = 1.0
reuse_fraction: float = 0.0
reuse_times:int = 0

# Experiment
device: str = "cuda"
Expand Down Expand Up @@ -584,26 +586,24 @@ def train(config: TrainConfig):
num_t = config.num_t
len_t = config.len_t
num_trials = config.num_berno
allow_overlap=config.bin_label_allow_overlap
reuse_fraction = config.reuse_fraction
reuse_times = config.reuse_times

if config.latent_reward:
dataset = scale_rewards(dataset)
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
latent_reward_model, indices = train_latent(dataset, pbrl_dataset, num_berno=num_trials, num_t=num_t, len_t=len_t)
dataset = predict_and_label_latent_reward(dataset, latent_reward_model, indices)
elif config.bin_label:
dataset = scale_rewards(dataset)
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
dataset = label_by_trajectory_reward(dataset, pbrl_dataset, num_t=num_t, len_t=len_t, num_trials=num_trials)
else:
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'saved/pbrl_datasets_no_overlap/pbrl_dataset_{config.env}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
pbrl_dataset = pick_and_generate_pbrl_dataset(dataset=dataset, env = config.env, num_t=num_t, len_t=len_t, num_trials=num_trials,
allow_overlap=allow_overlap, reuse_fraction=reuse_fraction, reuse_times=reuse_times)
dataset = label_by_original_rewards(dataset, pbrl_dataset, num_t)
dataset = small_d4rl_dataset(dataset, dataset_size_multiplier=config.dataset_size_multiplier)
print(f'Dataset size: {(dataset["observations"]).shape[0]}')
Expand Down
Loading

0 comments on commit 21590e3

Please sign in to comment.