-
Notifications
You must be signed in to change notification settings - Fork 1
/
switchtask.py
66 lines (52 loc) · 1.69 KB
/
switchtask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from data import DataSet
from utils import clear_dirs, save_batch, load_losses, train_SWTSK, run_validation, loadValLoss
import pandas as pd
import numpy as np
import math
""""
task = [1]
For each task:
1. Create a new csv with task samples.
2. Train network
3. Store losses upon early stopping
4. Append new task in taskqueue.
"""
class SWTSK:
    """Switch-task trainer.

    Mixes the samples of every task into one shuffled training batch and
    repeatedly trains until the variance of the recent training losses
    falls below a threshold, running validation after each round.

    NOTE(review): `save_batch`, `train_SWTSK`, `load_losses`,
    `run_validation` and `loadValLoss` are project utilities from
    `utils`; their exact contracts are assumed from usage here.
    """

    def __init__(self, data, var_threshold=0.025, window=30):
        """Initialize from a dataset object.

        Args:
            data: object exposing a `tasks` collection — presumably a
                `DataSet` from data.py; each entry is a list of samples
                (confirm against the caller).
            var_threshold: early-stop cutoff on the variance of recent
                training losses (default preserves original 0.025).
            window: how many trailing losses the variance is computed
                over (default preserves original 30).
        """
        self.tasks = data.tasks
        self.loss_hist = []       # training losses accumulated across rounds
        self.val_loss_hist = []   # one validation loss per training round
        self.var_threshold = var_threshold
        self.window = window

    def create_task(self, task_q):
        """Concatenate the samples of every task id in `task_q`, shuffle
        them in place, and persist the mix as the 'switch-task-train' batch.
        """
        curr_data = []
        for task in task_q:
            curr_data.extend(self.tasks[task])
        np.random.shuffle(curr_data)
        save_batch(curr_data, "switch-task-train")

    def train(self):
        """Train on all tasks at once until the loss variance converges.

        Loops train/validate rounds until the variance of the last
        `self.window` training losses drops below `self.var_threshold`,
        then saves the validation-loss history to disk.
        """
        # All tasks go into the queue together: [0, 1, ..., num_tasks-1].
        self.create_task(list(range(len(self.tasks))))

        # math.inf sentinel guarantees at least one training round
        # (original used the magic number 9999).
        loss_variance = math.inf
        while loss_variance > self.var_threshold:
            train_SWTSK()
            self.loss_hist.extend(load_losses('SWTSK'))
            # NOTE(review): while fewer than `window` losses exist, the
            # variance of the short slice may already be tiny and end
            # training prematurely — original behavior kept.
            loss_variance = np.var(self.loss_hist[-self.window:])
            run_validation('SWTSK', '../history_23/')
            self.val_loss_hist.append(loadValLoss('../history_23/'))
        np.save('val-loss-hist-switch-task.npy', np.array(self.val_loss_hist))
'''
Reward-scaling scratch notes (kept for reference):

c = 5, o = 10, cr = 0.5, 1-cr = 0.5
c = 1, o = 3, cr = 0.33, 1-cr = 0.66
cr = (o - c)/o
NOTE(review): the second example contradicts the formula above —
(3 - 1)/3 = 0.66, not 0.33. Either the intended formula is cr = c/o
(5/10 = 0.5, 1/3 = 0.33 matches both examples) or cr and 1-cr are
swapped in the second line; confirm with the author.
t(x) = length of the longest sequence --> scaling factor for reward
rewards :
0: [100,50,20,10]
1: [50,30, 10,5]
2: [10, 5, 0, 1]
rewards = avg_rew = [50.33, 26, 10, 3.33]
'''