-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
83 lines (65 loc) · 2.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from data import DataSet
from bandits import UCB1, EXP3
from contextual import LinUCB
from switchtask import SWTSK
from utils import clear_dirs
import pandas as pd
import numpy as np
import math
import os
import json
import argparse
'''
Args:
    df_path: path to the df
    num_tasks
    csv: training csv from DeepSpeech
    num_episodes
    batch_size
    batch_path (str): path to save training batch for DeepSpeech
    c=0.01
    gain_type='SPG'

Example usage (note: the parser below also requires --mode, --hist_path and
--model_path; add them to these commands before running):
python main.py --df_path='tt_train_with_scores.csv' --num_tasks=3 --csv='train.csv' --num_episodes=1 --batch_size=64 --gain_type='PG' --c 0.5
python main.py --mode='EXP3' --df_path='tt_train_with_scores.csv' --num_tasks=3 --csv='train.csv' --num_episodes=1 --batch_size=64 --gain_type='PG' --c 0.1
'''
def main(args):
    """Run the algorithm selected by ``args.mode`` on the dataset at ``args.df_path``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``mode``, ``model_path``, ``df_path``, ``num_tasks``,
            ``batch_size``, ``num_episodes``, ``hist_path``, ``c`` and
            ``gain_type``.

    Raises:
        ValueError: If ``args.mode`` is not one of 'UCB1', 'EXP3', 'LinUCB'
            or 'SWTSK', or if ``args.batch_size`` is not positive.
    """
    # Validate before touching the filesystem: clear_dirs() runs next, and an
    # unknown mode used to fall through the dispatch below without any message
    # after that cleanup had already happened.
    known_modes = ('UCB1', 'EXP3', 'LinUCB', 'SWTSK')
    if args.mode not in known_modes:
        raise ValueError(
            'Unknown mode {!r}; expected one of: {}'.format(
                args.mode, ', '.join(known_modes)))
    if args.batch_size <= 0:
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(
                args.batch_size))

    # Clean up checkpoint dirs
    clear_dirs(args.mode, args.model_path)

    df = pd.read_csv(args.df_path)

    # Create dataset from csv
    data = DataSet(df, args.num_tasks)
    data.create_tasks()

    # Number of batches needed to cover every row once; the last batch may be
    # partial, hence the ceiling.
    num_timesteps = int(np.ceil(len(df) / args.batch_size))

    # NOTE(review): args.c is forwarded as-is to UCB1/EXP3; if the caller did
    # not set it, it may be None -- confirm those functions accept that.
    if args.mode == 'UCB1':
        print('Starting UCB1...')
        UCB1(dataset=data, csv=df,
             num_episodes=args.num_episodes,
             num_timesteps=num_timesteps,
             batch_size=args.batch_size,
             hist_path=args.hist_path,
             c=args.c,
             gain_type=args.gain_type)
    elif args.mode == 'EXP3':
        print('Starting EXP3...')
        EXP3(data, df, args.num_episodes, num_timesteps, args.batch_size,
             args.hist_path,
             args.c, args.gain_type)
    elif args.mode == 'LinUCB':
        LinUCB(data, args.hist_path, args.num_episodes, num_timesteps,
               args.batch_size, args.gain_type)
    elif args.mode == 'SWTSK':
        s = SWTSK(data)
        s.train()
if __name__ == '__main__':
    # Command-line entry point: parse the flags and hand them to main().
    # The description replaces the 'Process some integers.' placeholder left
    # over from the argparse documentation example.
    parser = argparse.ArgumentParser(
        description='Run one of the UCB1, EXP3, LinUCB or SWTSK algorithms '
                    'on a ranked training set.')
    parser.add_argument('--df_path', type=str, help='Path to the ranked train df', required=True)
    parser.add_argument('--num_tasks', type=int, help='Number of training buckets (lvls of difficulty)', required=True)
    parser.add_argument('--csv', type=str, help='Train csv from DeepSpeech filtered and preprocessed', required=True)
    parser.add_argument('--num_episodes', type=int, help='Num epochs', required=True)
    parser.add_argument('--batch_size', type=int, help='Batch size', required=True)
    # NOTE(review): --c is optional and has no default, so args.c is None when
    # the flag is omitted; confirm the consumers of args.c tolerate that.
    parser.add_argument('--c', type=float, help='Exploration rate')
    parser.add_argument('--gain_type', type=str, help='Gain type (Prediction Gain, Self-Prediction Gain)', required=True)
    # Restrict --mode to the values main() dispatches on, so a typo is
    # rejected at parse time instead of reaching main().
    parser.add_argument('--mode', type=str, help='Algorithms to run', required=True,
                        choices=['UCB1', 'EXP3', 'LinUCB', 'SWTSK'])
    parser.add_argument('--hist_path', type=str, help='Path to save history files', required=True)
    parser.add_argument('--model_path', type=str, help='Path to save models ', required=True)
    args = parser.parse_args()
    main(args)