Skip to content

Commit

Permalink
Bug fix in eval_latent for the multi-Bernoulli case
Browse files Browse the repository at this point in the history
  • Loading branch information
davidzhu27 committed Sep 27, 2024
1 parent 76041a8 commit 53c50db
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions algorithms/offline/pbrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,17 @@ def evaluate_latent_model(model, dataset, training_data, num_t, preferred_indice
with torch.no_grad():
# training eval
X_train, mu_train = training_data
latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1)
latent_r_sum_train = torch.sum(latent_rewards_train, dim=2)
latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0]
latent_mu_train = torch.bernoulli(latent_p_train).long()

mu_train_flat = mu_train.view(-1)
latent_mu_train_flat = latent_mu_train.view(-1)
assert(mu_train_flat.shape == latent_mu_train_flat.shape)
train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy())
print(f'Train Accuracy: {train_accuracy:.3f}')
if torch.all(mu_train.type(torch.int64) == mu_train):
latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1)
latent_r_sum_train = torch.sum(latent_rewards_train, dim=2)
latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0]
latent_mu_train = torch.bernoulli(latent_p_train).long()

mu_train_flat = mu_train.view(-1)
latent_mu_train_flat = latent_mu_train.view(-1)
assert(mu_train_flat.shape == latent_mu_train_flat.shape)
train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy())
print(f'Train Accuracy: {train_accuracy:.3f}')

# testing eval
t1s, t2s, ps = generate_pbrl_dataset_no_overlap(dataset, num_t=testing_num_t, len_t=len_t, save=False)
Expand Down

0 comments on commit 53c50db

Please sign in to comment.