diff --git a/submission.py b/submission.py
index 806abeb..2d5c9ab 100644
--- a/submission.py
+++ b/submission.py
@@ -20,8 +20,6 @@
 """
 
 import os
-import sys
-import argparse
 
 import pandas as pd
 from joblib import load
@@ -75,3 +73,31 @@ def predict_outcomes(df):
     predictions = model.predict(df)
     # Return the result as a Pandas DataFrame with the columns "nomem_encr" and "prediction"
     return pd.concat([nomem_encr, pd.Series(predictions, name="prediction")], axis=1)
+
+
+def test_submission(df, model_path="./model.joblib"):
+    """Smoke-test the prediction pipeline on the bundled fake data.
+
+    Args:
+        df: Ignored; kept for backward compatibility. The fake data set in
+            "data/fake_data.csv" next to this file is always used instead.
+        model_path: Path to the joblib-serialized model to load.
+
+    Returns:
+        A DataFrame with the columns "nomem_encr" and "prediction".
+    """
+    # Load fake data
+    fake_df = pd.read_csv(os.path.join(os.path.dirname(__file__), "data/fake_data.csv"))
+    # Take an explicit copy so that adding the "prediction" column below does
+    # not write into a slice of fake_df (avoids SettingWithCopyWarning).
+    ids = fake_df[["nomem_encr"]].copy()
+
+    # Clean data as you did before the model
+    features = clean_df(fake_df)
+
+    # Load model
+    model = load(model_path)
+
+    # Create prediction
+    ids["prediction"] = model.predict(features)
+    return ids