-
Notifications
You must be signed in to change notification settings - Fork 9
/
tfw_p3d.m
94 lines (71 loc) · 2.31 KB
/
tfw_p3d.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
classdef tfw_p3d < tfw_i
  %TFW_P3D pseudo 3D ConvNet for segmentation
  % Taking a volume patch as input, outputting the foreground score at the
  % center.
  %
  % Internal transform chain (ob.tfs), in order:
  %   1    tf_norm_ms            value normalization by mean and std
  %   2-4  tfw_ConvReluPoolDrop  three conv/relu/pool/dropout stages
  %   5    tf_conv               output convolution (the prediction)
  %   6    tf_loss_lse           loss; input 1 = prediction, input 2 = labels
  properties
  end

  methods
    function ob = tfw_p3d()
      % Build the DAG: create each transform and chain it to its predecessor.
      nStage = 3;                    % number of conv/relu/pool/drop stages
      tfs = cell(1, nStage + 3);     % norm + stages + output conv + loss

      % normalization of input values by mean and std
      tfs{1} = tf_norm_ms();

      % hidden stages, each fed by the output of the previous transform
      for k = 2 : nStage + 1
        tfs{k} = tfw_ConvReluPoolDrop();
        tfs{k}.i = tfs{k-1}.o;
      end

      % output convolution
      kOut = nStage + 2;
      tfs{kOut} = tf_conv();
      tfs{kOut}.i = tfs{kOut-1}.o;

      % loss: only input 1 (the prediction) is wired here; input 2 receives
      % the labels in fprop
      kLoss = kOut + 1;
      tfs{kLoss} = tf_loss_lse();
      tfs{kLoss}.i(1) = tfs{kOut}.o;

      % store the chain on the object
      ob.tfs = tfs;

      % outer data: inputs are X_bat and Y_bat (in that order); output is the loss
      ob.i = [n_data(), n_data()];
      ob.o = n_data();

      % gather the learnable parameters of all transforms
      ob.p = dag_util.collect_params( ob.tfs );
    end % tfw_p3d

    function ob = fprop(ob)
      % Forward pass: route the outer inputs into the chain, run every
      % transform in order, then expose the loss as the outer output.
      % ob.ab comes from the base class (looks like a compute backend providing
      % cvt_data/sync -- NOTE(review): confirm against tfw_i).
      nT = numel( ob.tfs );

      % outer input 1 (bat_X) -> first transform; outer input 2 (bat_Y) -> loss
      ob.tfs{1}.i.a = ob.ab.cvt_data( ob.i(1).a );
      ob.tfs{nT}.i(2).a = ob.ab.cvt_data( ob.i(2).a );

      for t = 1 : nT
        ob.tfs{t} = fprop( ob.tfs{t} );
        ob.ab.sync();
      end

      % the last transform's output is the loss value
      ob.o.a = ob.tfs{nT}.o.a;
    end % fprop

    function ob = bprop(ob)
      % Backward pass over the transforms from last down to the second.
      % No outer-output gradient needs to be injected first. tfs{1} (the
      % normalization) is skipped, so no gradient is written back to ob.i(1).d.
      for t = numel( ob.tfs ) : -1 : 2
        ob.tfs{t} = bprop( ob.tfs{t} );
        ob.ab.sync();
      end
    end % bprop

    % helper
    function Ypre = get_Ypre(ob)
      % Prediction = output activation of the transform just before the loss.
      nT = numel( ob.tfs );
      Ypre = ob.tfs{nT - 1}.o.a;
    end % get_Ypre
  end % methods
end % tfw_p3d