Skip to content

Commit

Permalink
feat: Add dependencies and update code for AutoGluon model
Browse files: browse the repository at this point in the history
  • Loading branch information
SverreNystad committed Mar 29, 2024
1 parent 0a825ce commit d6e28b8
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions models/auto_gluon.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"outputs": [],
"source": [
"%pip install autogluon\n",
"%pip install bokeh\n",
"%pip install scikit-learn"
]
},
Expand All @@ -61,15 +62,15 @@
"metadata": {},
"outputs": [],
"source": [
"%autoreload \n",
"%autoreload 2\n",
"\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"import autogluon.core as ag\n",
"import pandas as pd\n",
"from autogluon.tabular import TabularDataset, TabularPredictor\n",
"\n",
"from src.features.post_processor import save_predictions\n",
"from src.ml_service import prepare_data, prepare_test_data\n",
"from src.config import TARGET_FEATURES"
"from src.ml_service import prepare_data, prepare_test_data, save_predictions\n",
"from src.config import TARGET_FEATURE"
]
},
{
Expand All @@ -86,10 +87,11 @@
"outputs": [],
"source": [
"x_train, _, x_test, y_train, _, y_test = prepare_data(validation_size=0, test_size=0.1)\n",
"for target_feature_name in TARGET_FEATURES:\n",
" x_train[target_feature_name] = y_train\n",
"\n",
"data = TabularDataset(x_train)"
"combined_train_data = pd.concat([x_train, y_train], axis=1)\n",
"combined_test_data = pd.concat([x_test, y_test], axis=1)\n",
"training_data = TabularDataset(combined_train_data)\n",
"test_data = TabularDataset(combined_test_data)"
]
},
{
Expand Down Expand Up @@ -122,9 +124,9 @@
"outputs": [],
"source": [
"# Initialize the AutoGluon TabularPredictor\n",
"time_limit = 24*60*60 # Set this to longest time you are willing to wait (in seconds)\n",
"metric = 'roc_auc'\n",
"predictor = TabularPredictor(label=target_feature_name, eval_metric=metric).fit(data, time_limit=time_limit, presets='best_quality')"
"time_limit = 3*60 #24*60*60 # Set this to longest time you are willing to wait (in seconds)\n",
"metric = 'log_loss'\n",
"predictor = TabularPredictor(label=TARGET_FEATURE, eval_metric=metric).fit(training_data, time_limit=time_limit, presets='best_quality')"
]
},
{
Expand Down Expand Up @@ -166,7 +168,7 @@
"test_accuracy = accuracy_score(y_test, y_test_pred)\n",
"print(\"Test Accuracy: \", test_accuracy)\n",
"print(\"Test Classification Report:\\n\", classification_report(y_test, y_test_pred))\n",
"# predictor.leaderboard(x_test, silent=True)\n"
"predictor.leaderboard(test_data)\n"
]
},
{
Expand All @@ -182,8 +184,8 @@
"metadata": {},
"outputs": [],
"source": [
"x_test = prepare_test_data()\n",
"final_predictions = predictor.predict(x_test)"
"x_final_test = prepare_test_data()\n",
"final_predictions = predictor.predict(x_final_test)"
]
},
{
Expand Down

0 comments on commit d6e28b8

Please sign in to comment.