forked from prasenjit52282/AQuaMoHo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvarryDiv.py
61 lines (56 loc) · 1.97 KB
/
varryDiv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import glob
import os

import numpy as np
import pandas as pd

from library.constants import *
from library.models.rnn import RNN
from library.models.rf import RandomForest, OldRandomForest
#-----------------------------------------------------------------
#Delhi
num_random_exp = 10  # number of random train/test splits per test-set size
# Sorted because glob returns entries in arbitrary (filesystem-dependent)
# order; without sorting, the same NumPy seed would hold out different devices
# on different machines and the experiments would not be reproducible.
dev_list = sorted(glob.glob("./Data/Delhi/*"))


def get_train_test_lists(num_test_dev=1, devices=None):
    """Randomly split device paths into disjoint train and test sets.

    Sampling uses NumPy's global RNG (``np.random.choice``), so the split is
    reproducible once ``np.random.seed`` has been called.

    Args:
        num_test_dev: Number of devices to hold out for testing. Must be at
            least 1 and must leave at least one device for training.
        devices: Candidate device paths. Defaults to the module-level
            ``dev_list`` (the sorted ``./Data/Delhi/*`` entries).

    Returns:
        ``(train_devs, test_devs)`` as NumPy arrays of paths. ``train_devs``
        preserves the order of ``devices``.

    Raises:
        ValueError: If ``num_test_dev`` is outside ``[1, len(devices) - 1]``.
    """
    pool = dev_list if devices is None else list(devices)
    if not 1 <= num_test_dev < len(pool):
        # Fail fast: an empty training set (or an oversized sample) would
        # otherwise surface later as an obscure error inside model training.
        raise ValueError(
            f"num_test_dev must satisfy 1 <= num_test_dev < {len(pool)} "
            f"(the number of devices), got {num_test_dev}"
        )
    test_devs = np.random.choice(pool, num_test_dev, replace=False)
    held_out = set(test_devs)
    # Filter the ordered pool rather than iterating a set difference: str
    # hashes are salted per process, so set iteration order (and hence the
    # training-file order) would change from run to run despite a fixed seed.
    train_devs = np.array([dev for dev in pool if dev not in held_out])
    return train_devs, test_devs
#Old Random Forest
# Benchmark OldRandomForest while varying how many devices are held out for
# testing (1-6), repeating each size over num_random_exp random splits.
os.makedirs("./logs/exp", exist_ok=True)  # to_csv does not create missing dirs
np.random.seed(seed)  # NOTE(review): `seed` presumably comes from library.constants (star import)
# Draw every split before any training so the splits depend only on `seed`
# and not on whether model construction/training also consumes NumPy's
# global RNG. If the model never touches it, these are exactly the splits an
# in-loop draw would produce.
splits = [
    (num_test_dev, random_exp, *get_train_test_lists(num_test_dev))
    for num_test_dev in [1, 2, 3, 4, 5, 6]
    for random_exp in range(num_random_exp)
]
l = []
for num_test_dev, random_exp, train_devs, test_devs in splits:
    model = OldRandomForest()
    model.train_on_file_sets(train_devs, test_devs)
    # Copy so each row is an independent record rather than an alias of the
    # model's own metrics object (assumes `metrics` is a dict-like mapping of
    # metric name -> value).
    met = dict(model.metrics)
    met["num_test_dev"] = num_test_dev
    met["rand_exp_id"] = random_exp
    l.append(met)
df = pd.DataFrame(l)
df.to_csv("./logs/exp/delhi_rf_old_varryDiv.csv", index=False)
#Random Forest
# Benchmark RandomForest while varying how many devices are held out for
# testing (1-6), repeating each size over num_random_exp random splits.
os.makedirs("./logs/exp", exist_ok=True)  # to_csv does not create missing dirs
np.random.seed(seed)  # NOTE(review): `seed` presumably comes from library.constants (star import)
# Draw every split before any training so the splits depend only on `seed`
# and not on whether model construction/training also consumes NumPy's
# global RNG. If the model never touches it, these are exactly the splits an
# in-loop draw would produce.
splits = [
    (num_test_dev, random_exp, *get_train_test_lists(num_test_dev))
    for num_test_dev in [1, 2, 3, 4, 5, 6]
    for random_exp in range(num_random_exp)
]
l = []
for num_test_dev, random_exp, train_devs, test_devs in splits:
    model = RandomForest()
    model.train_on_file_sets(train_devs, test_devs)
    # Copy so each row is an independent record rather than an alias of the
    # model's own metrics object (assumes `metrics` is a dict-like mapping of
    # metric name -> value).
    met = dict(model.metrics)
    met["num_test_dev"] = num_test_dev
    met["rand_exp_id"] = random_exp
    l.append(met)
df = pd.DataFrame(l)
df.to_csv("./logs/exp/delhi_rf_varryDiv.csv", index=False)
#RNN
# Benchmark the RNN while varying how many devices are held out for testing
# (1-6), repeating each size over num_random_exp random splits.
os.makedirs("./logs/exp", exist_ok=True)  # to_csv does not create missing dirs
np.random.seed(seed)  # NOTE(review): `seed`, `epochs`, `batch_size` presumably come from library.constants
# Draw every split before any training. RNN is constructed with `seed=seed`;
# if it reseeds NumPy's global RNG, drawing inside the loop would repeat the
# same split for every random_exp. Pre-drawing makes the splits depend only
# on `seed`, and equals the in-loop draws when the model leaves the RNG alone.
splits = [
    (num_test_dev, random_exp, *get_train_test_lists(num_test_dev))
    for num_test_dev in [1, 2, 3, 4, 5, 6]
    for random_exp in range(num_random_exp)
]
l = []
for num_test_dev, random_exp, train_devs, test_devs in splits:
    # NOTE(review): every run passes the same checkpoint path; if RNN writes
    # there, each run overwrites the previous one's file — confirm RNN never
    # restores a checkpoint left behind by an earlier run.
    model = RNN(checkpoint_filepath="./logs/model/delhi_rnn_vd", seed=seed)
    model.train_on_file_sets(
        train_devs, test_devs, epochs=epochs, batch_size=batch_size
    )
    # Copy so each row is an independent record rather than an alias of the
    # model's own metrics object (assumes `metrics` is a dict-like mapping of
    # metric name -> value).
    met = dict(model.metrics)
    met["num_test_dev"] = num_test_dev
    met["rand_exp_id"] = random_exp
    l.append(met)
df = pd.DataFrame(l)
df.to_csv("./logs/exp/delhi_rnn_varryDiv.csv", index=False)