Skip to content

Commit

Permalink
Merge pull request #7 from alexander-hn/model-predict
Browse files Browse the repository at this point in the history
Model: add prediction function
  • Loading branch information
TomasPhilippart authored Nov 27, 2024
2 parents ab6d491 + 1595353 commit d5e7d3c
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions machinelearning/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,42 @@ def test(file: str):
roc_display.plot(ax=ax2)
plt.savefig(FILES['model']['results'])

def pred():
X_test = pd.read_csv(FILES['testing']['data'])
#y_test = get_labels(X_test, file)
pidarr = X_test['PID']

# get rid of PID column, dont use for training
X_test.drop(columns=['PID'], inplace=True)

# make sure to use the same features as for training
training_features = load(FILES['model']['features'])
test_features = X_test.columns
# drop new features
X_test.drop(columns=[f for f in test_features if f not in training_features], inplace=True)
# add missing features
for f in training_features:
if f not in test_features:
X_test[f] = 0

# scale the test data
scaler = load(FILES['model']['scaler'])
scaler.transform(X_test)

# predict with the previously trained classifier
classifier = load(FILES['model']['model'])

prediction = pd.DataFrame({'PID': pidarr, 'Pred': classifier.predict(X_test)})
#pd.set_option('display.max_rows', len(prediction))
print(prediction.to_string())


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--train', action='store_true', help='Train a new model')
parser.add_argument('--test', action='store_true', help='Test an existing model')
parser.add_argument('--labels', default='file', choices=['file','data'], help='Read labels from file or data')
parser.add_argument('--predict', action='store_true', help='Predict using an existing model')
args = parser.parse_args()

if args.train:
Expand All @@ -171,6 +201,9 @@ def main():
else:
test(file=None)

if args.predict:
pred()


if __name__ == '__main__':
main()

0 comments on commit d5e7d3c

Please sign in to comment.