diff --git a/algorithms/offline/pbrl.py b/algorithms/offline/pbrl.py
index 4a5a4ec..c9ff21a 100644
--- a/algorithms/offline/pbrl.py
+++ b/algorithms/offline/pbrl.py
@@ -217,16 +217,17 @@ def evaluate_latent_model(model, dataset, training_data, num_t, preferred_indice
     with torch.no_grad():
         # training eval
         X_train, mu_train = training_data
-        latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1)
-        latent_r_sum_train = torch.sum(latent_rewards_train, dim=2)
-        latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0]
-        latent_mu_train = torch.bernoulli(latent_p_train).long()
-
-        mu_train_flat = mu_train.view(-1)
-        latent_mu_train_flat = latent_mu_train.view(-1)
-        assert(mu_train_flat.shape == latent_mu_train_flat.shape)
-        train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy())
-        print(f'Train Accuracy: {train_accuracy:.3f}')
+        if torch.all(mu_train.type(torch.int64) == mu_train):
+            latent_rewards_train = model(X_train).view(num_t, 2, len_t, -1)
+            latent_r_sum_train = torch.sum(latent_rewards_train, dim=2)
+            latent_p_train = torch.nn.functional.softmax(latent_r_sum_train, dim=1)[:,0]
+            latent_mu_train = torch.bernoulli(latent_p_train).long()
+
+            mu_train_flat = mu_train.view(-1)
+            latent_mu_train_flat = latent_mu_train.view(-1)
+            assert(mu_train_flat.shape == latent_mu_train_flat.shape)
+            train_accuracy = accuracy_score(mu_train_flat.cpu().numpy(), latent_mu_train_flat.cpu().numpy())
+            print(f'Train Accuracy: {train_accuracy:.3f}')
 
         # testing eval
         t1s, t2s, ps = generate_pbrl_dataset_no_overlap(dataset, num_t=testing_num_t, len_t=len_t, save=False)