From 53c50dbbd16c0932210deffefd7c37c0c0dd9259 Mon Sep 17 00:00:00 2001 From: yaboidav3 Date: Thu, 26 Sep 2024 19:55:44 -0500 Subject: [PATCH] bug fixing for eval_latent for multi berno --- algorithms/offline/pbrl.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/algorithms/offline/pbrl.py b/algorithms/offline/pbrl.py index 4a5a4ec..c9ff21a 100644 --- a/algorithms/offline/pbrl.py +++ b/algorithms/offline/pbrl.py @@ -217,16 +217,17 @@ def evaluate_latent_model(model, dataset, training_data, num_t, preferred_indice with torch.no_grad(): # training eval X_train, mu_train = training_data - latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1) - latent_r_sum_train = torch.sum(latent_rewards_train, dim=2) - latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0] - latent_mu_train = torch.bernoulli(latent_p_train).long() - - mu_train_flat = mu_train.view(-1) - latent_mu_train_flat = latent_mu_train.view(-1) - assert(mu_train_flat.shape == latent_mu_train_flat.shape) - train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy()) - print(f'Train Accuracy: {train_accuracy:.3f}') + if torch.all(mu_train.type(torch.int64) == mu_train): + latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1) + latent_r_sum_train = torch.sum(latent_rewards_train, dim=2) + latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0] + latent_mu_train = torch.bernoulli(latent_p_train).long() + + mu_train_flat = mu_train.view(-1) + latent_mu_train_flat = latent_mu_train.view(-1) + assert(mu_train_flat.shape == latent_mu_train_flat.shape) + train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy()) + print(f'Train Accuracy: {train_accuracy:.3f}') # testing eval t1s, t2s, ps = generate_pbrl_dataset_no_overlap(dataset, num_t=testing_num_t, len_t=len_t, save=False)