forked from lawrennd/gp
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdemInterpolationGp.m
100 lines (93 loc) · 2.88 KB
/
demInterpolationGp.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
% DEMINTERPOLATIONGP Demonstrate Gaussian processes for interpolation.
% FORMAT
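% DESC runs a simple one-dimensional Gaussian process interpolation demo, plotting the posterior mean and two-standard-deviation error bars as more training points are added.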
%
% SEEALSO : gpCreate, demRegressionGp
%
% COPYRIGHT : Neil D. Lawrence, 2006, 2008
% GP
% Seed the random generators so the sampled "true" function is reproducible.
randn('seed', 1e6)
rand('seed', 1e6)

% Create data set: nine evenly spaced inputs on [-1, 1] and function values
% sampled from a Gaussian with zero mean and RBF-kernel covariance K.
x = linspace(-1, 1, 9)';
trueKern = kernCreate(x, 'rbf');
K = kernCompute(trueKern, x);
% Sample some true function values (presumably a single draw; the result is
% transposed to a column to line up with x).
yTrue = gsamp(zeros(size(x))', K, 1)';

% Plot styling shared by every figure.
markerSize = 20;
markerWidth = 6;
markerType = 'k.';
lineWidth = 2;

% Nested training sets, given as indices into x: 2, 3, 5 and then all 9
% points. Assigned as a single cell literal so that a stale indTrain left in
% the (shared) script workspace cannot add extra stages or break the
% brace-indexed assignment.
indTrain = {[1 9]', [1 5 9]', [1 3 5 7 9]', (1:9)'};
figNo = 1;
fillColor = [0.7 0.7 0.7];
% Fit a noise-free GP to each nested training set in turn and plot it.
% For stage i (0..numStages) this draws, in consecutive figures:
%   * a "fit" figure (i > 0): posterior mean, +/- two standard deviation
%     band, and the training points indTrain{i};
%   * a "preview" figure (i < numStages): the same posterior (if any) with
%     the next, larger training set indTrain{i+1} overlaid.
% Stage 0 has no fit, so its only figure shows the first training set.

% The interpolating kernel and test inputs do not change between stages,
% so build them once.
kern = kernCreate(x, 'rbf');
% Change the inverse width (1/lengthScale^2) of the RBF kernel.
kern.inverseWidth = 5;
xTest = linspace(-2, 2, 200)';
kernDiagTest = kernDiagCompute(kern, xTest);
numStages = numel(indTrain);
for i = 0:numStages
  if i > 0
    xTrain = x(indTrain{i});
    yTrain = yTrue(indTrain{i});
    % GP posterior mean and variance at the test inputs.
    Kx = kernCompute(kern, xTest, xTrain);
    Ktrain = kernCompute(kern, xTrain, xTrain);
    KxInvK = Kx*pdinv(Ktrain);
    yPred = KxInvK*yTrain;
    yVar = kernDiagTest - sum(KxInvK.*Kx, 2);
    % Round-off can push the variance slightly below zero near the
    % training inputs; clamp it so the standard deviation stays real.
    ySd = sqrt(max(yVar, 0));
  end

  % Index sets whose points are drawn this stage: the current training set
  % (fit figure) followed by the next one (preview figure).
  plotSets = {};
  if i > 0
    plotSets{end+1} = indTrain{i};
  end
  if i < numStages
    plotSets{end+1} = indTrain{i+1};
  end

  for j = 1:numel(plotSets)
    figure(figNo)
    % Start from empty axes so re-running the demo does not draw on top of
    % (held) content from a previous run.
    clf
    if i > 0
      % The band polygon goes left-to-right along the upper edge and back
      % right-to-left along the lower edge, so both yPred and ySd must be
      % reversed on the return path.
      fill([xTest; xTest(end:-1:1)], ...
           [yPred; yPred(end:-1:1)] ...
           + 2*[ySd; -ySd(end:-1:1)], ...
           fillColor, 'EdgeColor', fillColor)
      hold on
      h = plot(xTest, yPred, 'k-');
      %/~
      % h = [h plot(xTest, yPred + 2*ySd, 'b--')];
      % h = [h plot(xTest, yPred - 2*ySd, 'b--')];
      %~/
      set(h, 'linewidth', lineWidth)
    end
    p = plot(x(plotSets{j}), yTrue(plotSets{j}), markerType);
    set(p, 'markersize', markerSize, 'linewidth', markerWidth);
    set(gca, 'xtick', [-2 -1 0 1 2]);
    set(gca, 'ytick', [-3 -2 -1 0 1 2 3]);
    set(gca, 'fontname', 'times', 'fontsize', 18, 'xlim', [-2 2], 'ylim', [-3 3])
    zeroAxes(gca);
    % Only print when the caller set a workspace flag named printDiagram
    % ('var' stops exist() matching a function/file of that name).
    if exist('printDiagram', 'var') && printDiagram
      printPlot(['demInterpolation' num2str(figNo)], '../tex/diagrams', '../html');
    end
    figNo = figNo + 1;
  end
end