-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
76 lines (58 loc) · 2.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import storage
import model
import os
import shutil
import reg_blocker
import random
class BaoTrainingException(Exception):
    """Raised when a Bao model cannot be trained (e.g. no stored experience)."""
def train_and_swap(fn, old, tmp, verbose=False, max_retries=5):
    """Train a candidate Bao model and, if accepted, swap it in at ``fn``.

    A new model is trained and saved to ``tmp``. If a model already exists at
    ``fn`` it is loaded, and ``reg_blocker.should_replace_model`` decides
    whether the candidate may replace it. While the candidate is rejected, it
    is retrained with ``emphasize_experiments`` set to the retry number, for
    at most ``max_retries`` training attempts in total. If no acceptable model
    is produced, the function prints a message and returns without touching
    ``fn`` or ``old``.

    On acceptance, any existing model at ``fn`` is moved to ``old`` (whatever
    was at ``old`` is deleted first) and ``tmp`` is renamed to ``fn``.

    Args:
        fn: path of the active model.
        old: path the previously active model is moved to. It is removed with
            ``shutil.rmtree``, so presumably a directory; confirm with
            ``BaoRegression.save``.
        tmp: scratch path the candidate model is saved to.
        verbose: forwarded to ``train_and_save_model``.
        max_retries: maximum number of training attempts, counting the first.
    """
    if os.path.exists(fn):
        old_model = model.BaoRegression(have_cache_data=True)
        old_model.load(fn)
    else:
        old_model = None

    new_model = train_and_save_model(tmp, verbose=verbose)
    current_retry = 1
    while not reg_blocker.should_replace_model(old_model, new_model):
        # Plain comparison. A chained `a >= b == 0` would also require
        # max_retries == 0, making this exit unreachable and the loop unbounded.
        if current_retry >= max_retries:
            print("Could not train model with better regression profile.")
            return
        print("New model rejected when compared with old model. "
              + "Trying to retrain with emphasis on regressions.")
        print("Retry #", current_retry)
        new_model = train_and_save_model(tmp, verbose=verbose,
                                         emphasize_experiments=current_retry)
        current_retry += 1

    # Accepted: archive the current model (if any), then promote the candidate.
    if os.path.exists(fn):
        shutil.rmtree(old, ignore_errors=True)
        os.rename(fn, old)
    os.rename(tmp, fn)
def train_and_save_model(fn, verbose=True, emphasize_experiments=0):
    """Fit a ``model.BaoRegression`` on stored experience and save it to ``fn``.

    Args:
        fn: path passed to ``BaoRegression.save``.
        verbose: forwarded to the ``model.BaoRegression`` constructor.
        emphasize_experiments: number of extra copies of
            ``storage.experiment_experience()`` appended to the training set,
            which over-represents those samples in the fit.

    Returns:
        The fitted ``model.BaoRegression`` instance.

    Raises:
        BaoTrainingException: if there is no experience to train on.
    """
    # Copy before extending, in case storage hands back a list it still owns.
    all_experience = list(storage.experience())
    for _ in range(emphasize_experiments):
        all_experience.extend(storage.experiment_experience())

    if not all_experience:
        raise BaoTrainingException("Cannot train a Bao model with no experience")
    if len(all_experience) < 20:
        print("Warning: trying to train a Bao model with fewer than 20 datapoints.")

    # Element 0 of each entry is used as the input and element 1 as the target.
    x = [entry[0] for entry in all_experience]
    y = [entry[1] for entry in all_experience]

    reg = model.BaoRegression(have_cache_data=True, verbose=verbose)
    reg.fit(x, y)
    reg.save(fn)
    return reg
if __name__ == "__main__":
    # CLI entry point: train a model into MODEL_FILE, then reload it as a
    # smoke test that the saved artifact can be loaded.
    import sys

    if len(sys.argv) != 2:
        print("Usage: train.py MODEL_FILE", file=sys.stderr)
        # sys.exit rather than the site-injected builtin `exit`, which is
        # absent under `python -S` and in some embedded interpreters.
        sys.exit(-1)

    train_and_save_model(sys.argv[1])
    print("Model saved, attempting load...")
    reg = model.BaoRegression(have_cache_data=True)
    reg.load(sys.argv[1])