-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
147 lines (110 loc) · 4.31 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Flask service that sends images to TensorFlow Serving for prediction.

Note: this code was compiled together with TensorFlow Serving rather than
locally; the client runs inside the TensorFlow Docker container.
"""
from __future__ import print_function
import hmac
import os
import sys
# Communication to TensorFlow server via gRPC
# from grpc.beta import implementations
import tensorflow as tf
# TensorFlow serving stuff to send messages
# from tensorflow_serving.apis import predict_pb2
# from tensorflow_serving.apis import prediction_service_pb2
import cv2
import numpy as np
import requests
"""Flask container requires gRPC and tensorflow-API
to communicate with tensorflow-serving.
Tensorflow-serving-api is officially only available for Python 2,
although this appears to be a packaging-metadata issue more than
anything. Currently using:
https://github.com/illagrenan/tensorflow-serving-api-python3
(see https://github.com/tensorflow/serving/issues/700)"""
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
import logging
import grpc
from grpc import RpcError
from flask import Flask, jsonify, request, Response
from flask_cors import CORS
from gevent.pywsgi import WSGIServer
# Flask application with CORS enabled on all routes.
app = Flask(__name__)
CORS(app)

# Send log records to stdout at INFO level.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logging.info('flask app initialized')

# URL prefix for the versioned routes: /<SERVICE_VERSION>/<SERVICE_NAME>.
# Both variables are required; a missing one raises KeyError at import time.
master_path = '/{}/{}'.format(os.environ['SERVICE_VERSION'],
                              os.environ['SERVICE_NAME'])
def create_gprc_client(host):
    """Open an insecure gRPC channel to *host* and prepare a prediction call.

    A new channel is created on every call.

    Args:
        host: ``"address:port"`` of the TensorFlow Serving gRPC endpoint.

    Returns:
        A ``(stub, request)`` tuple: a ``PredictionServiceStub`` bound to the
        new channel and an empty ``PredictRequest`` to be filled by the caller.
    """
    stub = prediction_service_pb2.PredictionServiceStub(
        grpc.insecure_channel(host))
    return stub, predict_pb2.PredictRequest()
class PredictClientClassification:
    """gRPC client for a classification model hosted by TensorFlow Serving."""

    def __init__(self, host, model_name, model_version=0):
        """Store connection settings; no connection is opened here.

        Args:
            host: ``"address:port"`` of the TensorFlow Serving gRPC endpoint.
            model_name: name of the served model to query.
            model_version: model version to request; values <= 0 leave the
                version field unset in the request.
        """
        self.host = host
        self.model_name = model_name
        self.model_version = model_version

    @staticmethod
    def predict():
        """Placeholder for an abstract-method interface; always returns True.

        Static so that both ``instance.predict()`` and the class-level
        ``PredictClientClassification.predict()`` work (the original signature
        had no ``self``, so the instance call raised TypeError).
        """
        return True

    def predict_mnist(self, request_data, request_timeout=10):
        """Run *request_data* through the model's ``predict`` signature.

        Args:
            request_data: array-like with a ``.shape`` attribute; sent as the
                float32 input tensor ``input_0`` with that shape.
            request_timeout: gRPC deadline for the Predict call, in seconds.

        Returns:
            The ``float_val`` entries of output tensor ``output_0`` as a list,
            or None if the gRPC call raises ``RpcError`` (the error is logged).
        """
        stub, request = create_gprc_client(self.host)
        request.model_spec.name = self.model_name
        request.model_spec.signature_name = "predict"
        # Only pin a version when one was explicitly requested.
        if self.model_version > 0:
            request.model_spec.version.value = self.model_version
        features_tensor_proto = tf.contrib.util.make_tensor_proto(
            request_data, dtype=tf.float32, shape=request_data.shape)
        request.inputs['input_0'].CopyFrom(features_tensor_proto)
        try:
            result = stub.Predict(request, timeout=request_timeout)
        except RpcError as e:
            # Callers treat None as "prediction unavailable".
            logging.error('Predict call to %s failed: %s', self.host, e)
            return None
        return list(result.outputs['output_0'].float_val)
@app.route(master_path + '/')
def ok():
    """Root of the versioned path; returns a fixed greeting string."""
    greeting = "It's alive, Hello Stockholm-ai"
    return greeting
@app.route('/health_check')
def health_check():
    """Unversioned health-check endpoint; always answers 'healthy'."""
    status = 'healthy'
    return status
# Seconds to wait when downloading the image referenced by a request.
_IMAGE_FETCH_TIMEOUT = 10


def _is_authorized(auth_header):
    """Return True if *auth_header* matches the PREDICTOR_SECRET_KEY env var.

    An unset or empty secret rejects every request; otherwise a missing
    header (which defaults to '') would match an empty secret.
    """
    auth_key = os.environ.get('PREDICTOR_SECRET_KEY', '')
    if not auth_key:
        logging.error('PREDICTOR_SECRET_KEY is not set; rejecting request')
        return False
    # Constant-time comparison so response timing does not leak the key.
    return hmac.compare_digest(auth_header.encode('utf-8'),
                               auth_key.encode('utf-8'))


def _load_mnist_image(url):
    """Download *url* as a (1, 28, 28, 1) float32 grayscale batch in [0, 1].

    Returns None if the download fails or the bytes are not a decodable image.
    """
    # NOTE(review): the URL is caller-supplied, so this makes the server fetch
    # arbitrary addresses (SSRF risk); consider an allow-list of hosts.
    try:
        r = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT)
        r.raise_for_status()
    except requests.RequestException as e:
        logging.warning('could not fetch image %s: %s', url, e)
        return None
    nparr = np.frombuffer(r.content, np.uint8)
    if nparr.size == 0:
        return None
    # Flag 0 decodes as single-channel grayscale; None means undecodable bytes.
    img = cv2.imdecode(nparr, 0)
    if img is None:
        return None
    img = cv2.resize(img, (28, 28))
    img = (np.expand_dims(img, 0) / 255.).astype(np.float32)
    # Grayscale image so add channel last.
    return np.expand_dims(img, 3)


@app.route(master_path + '/predict/mnist_number', methods=['POST'])
def predict_image_position():
    """Classify the digit in the image at the posted URL.

    Expects the shared secret in the ``X-StockholmAI-Key`` header and a JSON
    body ``{"url": "<image url>"}``. On success, responds with
    ``{"output_classification": {...}}`` mapping indices 0-9 to the model's
    output values.

    Error responses:
        401: the key is wrong or the server has no key configured.
        400: the body is bad or the image cannot be fetched or decoded.
        502: the TensorFlow Serving call fails.
    """
    if not _is_authorized(request.headers.get('X-StockholmAI-Key', '')):
        return Response(
            '''authorized access only. make sure you set the
X-StockholmAI-Key header correctly.''',
            401)

    # silent=True: malformed or non-JSON bodies yield None instead of raising.
    prediction_request = request.get_json(silent=True)
    url = (prediction_request.get('url')
           if isinstance(prediction_request, dict) else None)
    if not url:
        return Response('request body must be JSON with a "url" field', 400)

    img = _load_mnist_image(url)
    if img is None:
        return Response('could not fetch or decode the image at "url"', 400)
    logging.info('predicting on image of shape %s', img.shape)

    mnist_client = PredictClientClassification(
        "serve_tensorflow:9001", "mnist_example_model")
    output = mnist_client.predict_mnist(img)
    if output is None:
        # predict_mnist returns None when the gRPC call fails.
        return Response('prediction service unavailable', 502)
    output_class = dict(zip(range(10), output))
    return jsonify(
        output_classification=output_class)
if __name__ == '__main__':
    # Serve on all interfaces, port 8080, with gevent's request log disabled.
    bind_address = ('0.0.0.0', 8080)
    print("serving!")
    WSGIServer(bind_address, app, log=None).serve_forever()