-
Notifications
You must be signed in to change notification settings - Fork 382
/
Copy pathuse_model.py
executable file
·106 lines (92 loc) · 3.49 KB
/
use_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python
# Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Amazon Software License (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/asl/
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express
# or implied. See the License for the specific language governing permissions
# and limitations under the License.
"""
Demonstrates how to use an ML Model, by setting the score threshold,
and kicks off a batch prediction job, which uses the ML Model to
generate predictions on new data. This script needs the id of the
ML Model to use. It also requires the score threshold.
Usage:
python use_model.py ml_model_id score_threshold s3_output_url
For example:
python use_model.py ml-12345678901 0.77 s3://your-bucket/prefix
"""
import base64
import datetime
import os
import random
import sys
import time

try:
    import urlparse  # Python 2
except ImportError:
    # Python 3 moved the module; bind it to the same name so that
    # urlparse.urlparse() keeps working unchanged.
    import urllib.parse as urlparse

import boto3
# The S3 URL of the sample data; the script entry point passes it to
# use_model() as the data to generate batch predictions for.
UNSCORED_DATA_S3_URL = "s3://aml-sample-data/banking-batch.csv"
def use_model(model_id, threshold, schema_fn, output_s3, data_s3url):
    """Set the score threshold on an ML Model and start a batch prediction.

    Polls the model until it reaches a terminal status, updates its score
    threshold, creates a datasource for the data to be scored, and then
    requests a batch prediction whose results are written to ``output_s3``.

    Args:
        model_id: Id of the ML Model to use.
        threshold: Score threshold to set on the model (a float).
        schema_fn: Path of the local file containing the data schema.
        output_s3: S3 URL passed as the batch prediction's ``OutputUri``.
        data_s3url: S3 URL of the data to generate predictions for.
    """
    ml = boto3.client('machinelearning')

    # Can't use the model until it's COMPLETED.
    # NOTE(review): poll_until_completed() also stops on FAILED / INVALID and
    # this function does not distinguish those cases before continuing.
    poll_until_completed(ml, model_id)
    ml.update_ml_model(MLModelId=model_id, ScoreThreshold=threshold)
    print("Set score threshold for %s to %.2f" % (model_id, threshold))

    # b32encode() returns bytes on Python 3; decode it so that concatenating
    # with the str prefix works on both Python 2 and Python 3.
    bp_id = 'bp-' + base64.b32encode(os.urandom(10)).decode('ascii')
    ds_id = create_data_source_for_scoring(ml, data_s3url, schema_fn)
    ml.create_batch_prediction(
        BatchPredictionId=bp_id,
        BatchPredictionName="Batch Prediction for marketing sample",
        MLModelId=model_id,
        BatchPredictionDataSourceId=ds_id,
        OutputUri=output_s3
    )
    print("Created Batch Prediction %s" % bp_id)
def poll_until_completed(ml, model_id):
    """Poll an ML Model until it reaches a terminal status.

    Calls ``ml.get_ml_model`` repeatedly, printing the status each time and
    sleeping between calls with a randomly growing delay, until the status
    is COMPLETED, FAILED or INVALID.

    Args:
        ml: Client object exposing ``get_ml_model(MLModelId=...)``.
        model_id: Id of the ML Model to poll.

    Returns:
        The terminal status string ('COMPLETED', 'FAILED' or 'INVALID'), so
        callers can tell a usable model from a failed one.
    """
    terminal_statuses = ('COMPLETED', 'FAILED', 'INVALID')
    delay = 2
    while True:
        model = ml.get_ml_model(MLModelId=model_id)
        status = model['Status']
        message = model.get('Message', '')
        now = str(datetime.datetime.now().time())
        print("Model %s is %s (%s) at %s" % (model_id, status, message, now))
        if status in terminal_statuses:
            return status

        # exponential backoff with jitter
        delay *= random.uniform(1.1, 2.0)
        time.sleep(delay)
def create_data_source_for_scoring(ml, data_s3url, schema_fn):
    """Create a datasource for the data that will be batch-scored.

    Args:
        ml: Client object exposing ``create_data_source_from_s3(...)``.
        data_s3url: S3 URL of the data, sent as ``DataLocationS3``.
        schema_fn: Path of the local file whose contents are sent as
            ``DataSchema``.

    Returns:
        The randomly generated id ('ds-' followed by a base32 token) of the
        datasource that was requested.
    """
    # b32encode() returns bytes on Python 3; decode it so that concatenating
    # with the str prefix works on both Python 2 and Python 3.
    ds_id = 'ds-' + base64.b32encode(os.urandom(10)).decode('ascii')

    # Read the schema with a context manager so the file handle is closed
    # promptly rather than left to the garbage collector.
    with open(schema_fn) as schema_file:
        schema = schema_file.read()

    ml.create_data_source_from_s3(
        DataSourceId=ds_id,
        DataSourceName="DS for Batch Prediction %s" % data_s3url,
        DataSpec={
            "DataLocationS3": data_s3url,
            "DataSchema": schema,
        },
        # Statistics are not requested for this scoring datasource.
        ComputeStatistics=False
    )
    print("Created Datasource %s for batch prediction" % ds_id)
    return ds_id
if __name__ == "__main__":
    # Parse and validate the command-line arguments:
    #   ml_model_id score_threshold s3_output_url
    try:
        model_id = sys.argv[1]
        threshold = float(sys.argv[2])
        s3_output_url = sys.argv[3]
        parsed_url = urlparse.urlparse(s3_output_url)
        if parsed_url.scheme != 's3':
            raise RuntimeError("s3_output_url must be an s3:// url")
    except IndexError:
        # Too few arguments: show the usage text and exit with an error code.
        print(__doc__)
        sys.exit(-1)
    except Exception:
        # Invalid argument value (e.g. a non-numeric threshold or a non-s3
        # URL): show the usage text, then let the error propagate. Catching
        # Exception rather than using a bare except leaves SystemExit and
        # KeyboardInterrupt alone.
        print(__doc__)
        raise
    use_model(model_id, threshold, "banking-batch.csv.schema",
              s3_output_url, UNSCORED_DATA_S3_URL)