-
Notifications
You must be signed in to change notification settings - Fork 28
/
train.m
117 lines (89 loc) · 2.69 KB
/
train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
%% Training
% Train the general SVM and RBT models from the original Kaggle data.
% NOTE: the 'Name' and 'use' issues flagged by the original author are
% still open.
startTime = tic;    % timer handle; read by toc() at the end of the script

%% Paths and run parameters
% Original data folder, exactly as downloaded from Kaggle
params.paths.or = 'R:\EEG Data\Original\';
% Training data is read from this folder...
params.paths.dataDir = 'R:\EEG Data\New\';
% ...and the new training/test sets are written to the same folder
params.paths.new = params.paths.dataDir;
% File the compact trained models are saved to (relative path)
params.paths.ModelPath = 'trainedModelsCompactTest.mat';
params.master = 61;    % version number
params.nSubs = 3;      % NOTE(review): presumably number of subjects - confirm in setParams

% All remaining parameters are defined inside setParams - edit them there
params = setParams(params);

% Overrides applied on top of the setParams defaults
params.plotOn = 0;                 % no plotting at script level
params.modParams.plotOn = false;   % no plotting during model fitting
params.redoCopy = 1;               % rebuild the training folder in the next cell
warning off MATLAB:table:RowsAddedExistingVars
%% Prepare raw data
% When params.redoCopy is set, rebuild the training directory from the
% original Kaggle data, following the list of safe files. According to the
% original author's notes, this step also creates the singles.mat file that
% this training set needs.
if params.redoCopy, copyTestLeakToTrain(params.paths), end
%% Process training set
params.tt = 'Train';

% Feature-set switches (1 = use the feature, 0 = skip it).
% NOTE(review): per the original author, this selection still needs to be
% stored in the seizure model during training.
clear use
use = struct( ...
    'hillsBandsLog2D', 0, ...
    'hillsBandsLogAv', 1, ...
    'maxHills2D',      1, ...
    'maxHillsAv',      1, ...
    'summ32D',         1, ...
    'summ3Av',         1, ...
    'bandsLin2D',      1, ...
    'bandsLinAv',      1, ...
    'maxBands2D',      1, ...
    'maxBandsAv',      1, ...
    'mCorrsT',         1, ...
    'mCorrsF',         1);

% Build the features object for the training set, then compile the
% available features into it.
disp('Creating basic features')
params.divS = [240, 160, 80];   % epoch window sizes
featuresTrain = featuresObject(params, use);
featuresTrain = featuresTrain.compileFeatures();
%% Run training
% Configure cross-validation and model hyperparameters, then train the
% general SVM and RBT models on the compiled training features.

% Cross-validation settings
params.cvParams.cvMode = 'Custom';
params.cvParams.k = 6;              % NOTE(review): presumably number of folds - confirm in trainModels
params.cvParams.evalProp = 0.2;
params.cvParams.overSample = 0.05;
% CV seed: it must sit on params.cvParams next to the other CV options.
% It was previously assigned to a misspelled field (params.cvParmas),
% so the intended seed never reached cvParams.
params.cvParams.seed = 2222;

% Settings shared by both models
params.modParams.keepIdx = featuresTrain.keepIdx;   % taken from the compiled features object
params.modParams.prior = 'Empirical';
params.modParams.hyper = 0;
params.modParams.standardize = true;
params.modParams.seed = 1111;

% SVM settings
% NOTE(review): presumably polynomial-kernel order and box constraint -
% confirm how trainModels consumes them.
params.modParams.polyOrder = 2;
params.modParams.BC = 1000;

% RBT settings (ensemble size, learning rate, max splits per learner)
params.modParams.nLearners = 100;
params.modParams.LearnRate = 1;
params.modParams.MaxNumSplits = 20;

% Train both general models
[SVMg, RBTg] = trainModels(featuresTrain, params);
%% Assess CV AUC and save to disk
% Print the CV AUC score of each general model
fprintf('SVM: general model AUC: %s\n', num2str(SVMg.AUCScore));
fprintf('RBT: general model AUC: %s\n', num2str(RBTg.AUCScore));

% Shrink both models and write the compact versions to
% params.paths.ModelPath
SVMgCompact = SVMg.shrink();
RBTgCompact = RBTg.shrink();
save(params.paths.ModelPath, 'SVMgCompact', 'RBTgCompact')

% Elapsed time since startTime was taken at the top of the script
endTime = toc(startTime);
fprintf('Training time taken: %s s\n', num2str(endTime));