-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path app.py
64 lines (42 loc) · 2.05 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

import joblib
from flask import Flask, request, jsonify, session, url_for, redirect, render_template

from flower_form import FlowerForm
# Load the trained classifier and its label encoder at import time.
# NOTE(review): joblib.load unpickles these files, so only ever point it at
# artifacts you produced yourself -- never at user-supplied paths.
classifier_loaded, encoder_loaded = map(
    joblib.load,
    (
        "saved_models/01.knn_with_iris_dataset.pkl",
        "saved_models/02.iris_label_encoder.pkl",
    ),
)
# prediction function
def make_prediction(model, encoder, sample_json):
# parse input from request
SepalLengthCm = sample_json['SepalLengthCm']
SepalWidthCm = sample_json['SepalWidthCm']
PetalLengthCm = sample_json['PetalLengthCm']
PetalWidthCm = sample_json['PetalWidthCm']
# Make an input vector
flower = [[SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm]]
# Predict
prediction_raw = model.predict(flower)
# Convert Species index to Species name
prediction_real = encoder.inverse_transform(prediction_raw)
return prediction_real[0]
app = Flask(__name__)
# SECRET_KEY signs the session cookie, and the session carries the
# measurements that /prediction trusts. A key committed to source lets anyone
# who reads the code forge that cookie, so take it from the environment.
# NOTE(review): the legacy literal is kept only as a fallback so existing
# local setups keep working -- always set SECRET_KEY in real deployments.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'mysecretkey')
@app.route("/", methods=['GET','POST'])
def index():
    """Serve the measurement form and stash a valid submission in the session.

    When ``form.validate_on_submit()`` is true, the four measurements are
    copied into the session and the client is redirected to the prediction
    view. Otherwise the form is rendered with ``home.html``.
    """
    form = FlowerForm()
    if not form.validate_on_submit():
        return render_template("home.html", form=form)
    # The redirect starts a new request, so the values travel via the session
    # for /prediction to read.
    for field_name in ('SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm'):
        session[field_name] = getattr(form, field_name).data
    return redirect(url_for("prediction"))
# Models: classifier_loaded and encoder_loaded are loaded near the top of the
# module. Re-loading the same files here only repeated the start-up
# unpickling work, so the loads are not repeated.
@app.route('/prediction')
def prediction():
    """Render the predicted species for the measurements stored in the session.

    The four measurements are expected to have been written to the session by
    ``index``. If any is missing (e.g. the user opened /prediction directly
    or the session expired) or cannot be converted to ``float``, the user is
    redirected back to the form instead of getting a server error.
    """
    feature_keys = ('SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm')
    try:
        content = {key: float(session[key]) for key in feature_keys}
    except (KeyError, TypeError, ValueError):
        # KeyError: value absent; TypeError/ValueError: value not numeric.
        return redirect(url_for("index"))
    results = make_prediction(classifier_loaded, encoder_loaded, content)
    return render_template('prediction.html', results=results)
if __name__ == '__main__':
    # Run Flask's built-in server on port 8080; host 0.0.0.0 listens on all
    # network interfaces, not just localhost.
    app.run(port=8080, host='0.0.0.0')