From 66c693e80706b9942254a02bfe66e4b9ffc7a8e3 Mon Sep 17 00:00:00 2001 From: Jeroen Date: Sun, 3 Mar 2024 10:23:33 +0100 Subject: [PATCH] Added manual test function --- submission.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/submission.py b/submission.py index 806abeb2..2d5c9abb 100644 --- a/submission.py +++ b/submission.py @@ -20,8 +20,6 @@ """ import os -import sys -import argparse import pandas as pd from joblib import load @@ -75,3 +73,20 @@ def predict_outcomes(df): predictions = model.predict(df) # Return the result as a Pandas DataFrame with the columns "nomem_encr" and "prediction" return pd.concat([nomem_encr, pd.Series(predictions, name="prediction")], axis=1) + + +def test_submission(df, model_path="./model.joblib"): + """Test if the code will work""" + # Load fake data + df = pd.read_csv(os.path.join(os.path.dirname(__file__), "data/fake_data.csv")) + ids = df[["nomem_encr"]] + + # Clean data as you did before the model + df = clean_df(df) + + # Load model + model = load(model_path) + + # Create prediction + ids["prediction"] = model.predict(df) + return ids