Skip to content

Commit

Permalink
unbounded
Browse files Browse the repository at this point in the history
  • Loading branch information
lytex committed Sep 4, 2024
1 parent f23db61 commit 9cf2ffd
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions new_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,25 +516,25 @@ def main(sigma = 20, sigma_upper = 5,

lightcurves = [lc for lc in lightcurves if lc is not None]

model_1_lazy = lambda : get_model_wrapper(lightcurves, use_wavelet=use_wavelet, binary_classification=binary_classification, frac=frac, test_size=test_size,
global_level_list=global_level_list, local_level_list=local_level_list,
l1=l1, l2=l2, dropout=dropout,
num_bins_global=num_bins_global,
num_bins_local=num_bins_local)
model_1_lazy = lambda : get_model_wrapper(lightcurves, use_wavelet=use_wavelet, binary_classification=binary_classification, frac=frac, test_size=test_size,
global_level_list=global_level_list, local_level_list=local_level_list,
l1=l1, l2=l2, dropout=dropout,
num_bins_global=num_bins_global,
num_bins_local=num_bins_local)

if k_fold is None:
model_1, history_1, num2class, X_val, y_val, X_test, y_test, recall_val = train_model(model_1_lazy, lightcurves,
use_wavelet=use_wavelet, binary_classification=binary_classification,
k_fold=k_fold, global_level_list=global_level_list, local_level_list=local_level_list, epochs=epochs, batch_size=batch_size, test_size=test_size)

if k_fold is None:
model_1, history_1, num2class, X_val, y_val, X_test, y_test, recall_val = train_model(model_1_lazy, lightcurves,
use_wavelet=use_wavelet, binary_classification=binary_classification,
k_fold=k_fold, global_level_list=global_level_list, local_level_list=local_level_list, epochs=epochs, batch_size=batch_size, test_size=test_size)


precision_val, recall_val, F1_val, Fβ_val, auc_val, cm_val, num2class = get_metrics(num2class, X_val, y_val, model_1, β=β, binary_classification=binary_classification, plot=True)
else:
# TODO añadir en el caso de k-fold
precision_val, recall_val, F1_val, Fβ_val, auc_val, cm_val, num2class = get_metrics(num2class, X_val, y_val, model_1, β=β, binary_classification=binary_classification, plot=True)
precision_val, recall_val, F1_val, Fβ_val, auc_val, cm_val, num2class = get_metrics(num2class, X_val, y_val, model_1, β=β, binary_classification=binary_classification, plot=True)
else:
# TODO añadir en el caso de k-fold
precision_val, recall_val, F1_val, Fβ_val, auc_val, cm_val, num2class = get_metrics(num2class, X_val, y_val, model_1, β=β, binary_classification=binary_classification, plot=True)

precision, recall, F1, , auc, cm, num2class = get_metrics(num2class, X_test, y_test, model_1, β=β, binary_classification=binary_classification, plot=True)

precision, recall, F1, , auc, cm, num2class = get_metrics(num2class, X_test, y_test, model_1, β=β, binary_classification=binary_classification, plot=True)



if apply_candidates:
Expand Down

0 comments on commit 9cf2ffd

Please sign in to comment.