-
Notifications
You must be signed in to change notification settings - Fork 14
/
example.m
44 lines (36 loc) · 1.08 KB
/
example.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
%% Load paths
% Put every folder below the current working directory on the MATLAB path so
% the edbn* helpers and the MNIST MAT-file can be found. This assumes the
% script is run from the repository root.
addpath(genpath('.'));
%% Load data
% Load only the variables this script uses. Function-form load with an
% explicit variable list avoids silently creating any other variables the
% MAT-file may contain.
load('mnist_uint8', 'train_x', 'train_y', 'test_x', 'test_y');
% Convert to double and rescale:
%  - pixel values (presumably uint8 in 0..255, given the /255) map to [0, 0.2];
%  - label entries (presumably one-hot 0/1) map to {0, 0.2}.
train_x = double(train_x) / 255 * 0.2;
test_x = double(test_x) / 255 * 0.2;
train_y = double(train_y) * 0.2;
test_y = double(test_y) * 0.2;
%% Train network
% Setup
% Seed the random number generator(s) so that runs are reproducible.
% The legacy rand('seed', sd) form is discouraged: it switches rand to a
% legacy generator and seeds ONLY rand, leaving randn unseeded. rng seeds the
% global stream shared by rand, randn and randi. Where rng is not available
% (e.g. older Octave releases), seed both legacy generators explicitly.
if exist('rng')
    rng(42);
else
    rand('seed', 42);
    randn('seed', 42);
end
clear edbn opts;
% Layer sizes: 784 visible units (presumably one per 28x28 MNIST pixel),
% 100 hidden units, 10 label units.
edbn.sizes = [784 100 10];
opts.numepochs = 6;  % presumably the number of training passes over the data
[edbn, opts] = edbnsetup(edbn, opts);
% Train
fprintf('Beginning training.\n');
edbn = edbntrain(edbn, train_x, opts);
% Use supervised training on the top layer
edbn = edbntoptrain(edbn, train_x, opts, train_y);
% Show results
figure;
visualize(edbn.erbm{1}.W'); % Visualize the RBM weights
% er is reported as (1 - er) * 100, so it is presumably an error rate in [0, 1].
er = edbntest(edbn, test_x, test_y);
fprintf('Scored: %2.2f\n', (1-er)*100);
%% Show the EDBN in action
% Run the trained network on the first test sample and collect its spikes.
spike_list = live_edbn(edbn, test_x(1, :), opts);
% Keep only spikes emitted by the last layer (the label layer).
output_idxs = (spike_list.layers == numel(edbn.sizes));
figure(2); clf;
% Spike addresses are shifted by -1 to map label units onto digits, so the
% digits are 0..num_labels-1. Use exactly one bin centre per label unit, so
% that every bar corresponds to a real digit and there is no empty bar for a
% non-existent digit num_labels.
num_labels = edbn.sizes(end);
hist(spike_list.addrs(output_idxs) - 1, 0:num_labels-1);
xlabel('Digit Guessed');
ylabel('Histogram Spike Count');
title('Label Layer Classification Spikes');
%% Export to xml
% Export the trained network and options via edbntoxml under the name
% 'mnist_edbn' (exact output file name is determined by edbntoxml).
edbntoxml(edbn, opts, 'mnist_edbn');