-
Notifications
You must be signed in to change notification settings - Fork 1
/
gp.jl
201 lines (167 loc) · 5.12 KB
/
gp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
### A Pluto.jl notebook ###
# v0.12.20
using Markdown
using InteractiveUtils
# This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error).
# Stand-in for Pluto's `@bind`: `def` is the variable to bind, `element` the widget expression.
macro bind(def, element)
quote
# Evaluate the widget expression in the caller's scope.
local el = $(esc(element))
# Give the bound global the widget's `Base.get` value when that method applies, else `missing`.
global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : missing
# The widget itself is the value of the macro call.
el
end
end
# ╔═╡ c823a222-6e1c-11eb-18a8-01a1b673d7b0
begin
# Install the packages this notebook relies on, then load them.
# `Pkg.add` runs each time this cell is evaluated.
# NOTE(review): "Plotly" is installed but no `using Plotly` is visible in this notebook;
# confirm whether it is still needed (e.g. as a Plots backend).
import Pkg;
Pkg.add("LinearAlgebra")
Pkg.add("Random")
Pkg.add("PlutoUI")
Pkg.add("Plots")
Pkg.add("Plotly")
# LinearAlgebra provides `svd`/`Diagonal`, Random provides `MersenneTwister`,
# Plots provides `plot`; PlutoUI is presumably the source of `Select`/`Slider`.
using LinearAlgebra
using Random
using Plots
using PlutoUI
md""" **0) import packages** """
end
# ╔═╡ 53b7c092-6e9b-11eb-0fa7-57fc1385457c
md"
# _The Gaussian Process_
Given a vector of pivot points $\boldsymbol x$,
the corresponding random vector $\boldsymbol y(\boldsymbol x)$
is drawn from a multivariate Gaussian
$\boldsymbol y \sim {\cal G}(\boldsymbol \mu,C)$
with zero mean $\boldsymbol \mu = \boldsymbol0$
and a covariance with matrix elements: $C_{ij} = C(x_i,x_j)$.
The different Gaussian processes differ in the kernel function $C(x,y)$
1) Straight line process: $C(x,y) = \alpha \cdot x\cdot y$
2) Wiener process: $C(x,y) = \alpha \cdot \min(x, y)$
3) Squared exponential kernel: $C(x,y) = \alpha \cdot \exp(- \frac{(x-y)^2}{\sigma^2})$
4) Ornstein-Uhlenbeck process: $C(x,y) = \alpha \cdot \exp(- \frac{|x-y|}{\sigma})$
GPs are magic ✨
"
# ╔═╡ 0ac101ba-6ecc-11eb-090e-6b5998c6bbfd
md"
## Setting up the mathematical details
"
# ╔═╡ dca69d60-6e8c-11eb-3305-3fa2cedd5dac
begin
    # Number of steps each parameter slider is divided into.
    Nslide = 100

    # One row per kernel (1-4); columns are the [lower upper] slider bounds for 𝛼.
    L_𝛼 = [
        0     100
        0.01  20
        0.01  40
        0.1   10
    ]

    # Same layout for 𝜎. Rows filled with -1 mark kernels that have no 𝜎 parameter
    # (other cells test `L_𝜎[ik,2] > 0` to decide whether to show a 𝜎 slider).
    L_𝜎 = [
        -1   -1
        -1   -1
        0.1  20
        1    20
    ]

    # Slider step sizes: (upper - lower) / Nslide, one entry per kernel.
    L_d_𝛼 = (L_𝛼[:, 2] .- L_𝛼[:, 1]) ./ Nslide
    L_d_𝜎 = (L_𝜎[:, 2] .- L_𝜎[:, 1]) ./ Nslide

    md"""**2) range of parameters**"""
end
# ╔═╡ 1edc6c32-6e1d-11eb-2f8a-b9dd2aff7b99
begin
    # Pivot points: an equally spaced grid from -25 to 25 in steps of 0.5,
    # materialised as a Vector.
    L_x0 = collect(-25:0.5:25)
    md""" **3) pivot points** """
end
# ╔═╡ 21cff596-6ecc-11eb-093c-f97fac9ab7cb
md"
## Try it out!
"
# ╔═╡ 2837bf6e-6eb2-11eb-2b2a-190fa3669085
# Drop-down bound to `kernel`; the option keys are the strings "1".."4",
# which other cells compare against and parse into the integer index `ik`.
md"""
**Choose a kernel:** $(@bind kernel Select(["1"=>"Straight lines","2"=>"Wiener process","3"=>"Squared exponential","4"=>"Ornstein-Uhlenbeck process"]))
"""
# ╔═╡ 0bef3462-6e1d-11eb-08f6-e3b2c36a0721
begin
    # Select, from the `kernel` key ("1".."4"):
    #   txt - plot title for the chosen process,
    #   L_x - pivot points at which the process is sampled,
    #   K   - covariance kernel, always callable as K(x, y, 𝛼, 𝜎).
    # Every branch defines K with the same four-argument signature so that the
    # sampling cell can call it uniformly; kernels without a length scale ignore 𝜎.
    if kernel == "1"
        # Straight line process: C(x, y) = 𝛼 ⋅ x ⋅ y
        txt = "Straight line";
        L_x = L_x0;
        K(x, y, 𝛼, 𝜎) = 𝛼 * x * y;
    elseif kernel == "2"
        # Wiener process: C(x, y) = 𝛼 ⋅ min(x, y); only the non-negative pivot
        # points are kept.
        txt = "Wiener process";
        L_x = [x for x in L_x0 if x >= 0];
        K(x, y, 𝛼, 𝜎) = 𝛼 * min(x, y);
    elseif kernel == "3"
        # Squared exponential: C(x, y) = 𝛼 ⋅ exp(-(x - y)² / 𝜎²)
        txt = "Squared exponential";
        L_x = L_x0;
        K(x, y, 𝛼, 𝜎) = 𝛼 * exp(-(x - y)^2 / 𝜎^2);
    elseif kernel == "4"
        # Ornstein-Uhlenbeck: C(x, y) = 𝛼 ⋅ exp(-|x - y| / 𝜎).
        # The minus sign matters: the covariance must decay with distance.
        # Without it the entries grow like exp(|x - y| / 𝜎) and the matrix is
        # not a valid (positive semi-definite) covariance.
        txt = "Ornstein-Uhlenbeck process";
        L_x = L_x0;
        K(x, y, 𝛼, 𝜎) = 𝛼 * exp(-abs(x - y) / 𝜎);
    else
        error("$(kernel) not supported")
    end
    # Number of pivot points = dimension of the sampled Gaussian vector.
    n = length(L_x);
    md"""**1) kernel definitions**"""
end
# ╔═╡ bc5fc60a-6e92-11eb-0003-b58331ceddcf
# Integer form of the selected kernel key; used as the row index into L_𝛼, L_𝜎, L_d_𝛼 and L_d_𝜎.
ik = parse(Int32,kernel);
# ╔═╡ 6be2bb7a-6eb6-11eb-2cfb-87383e68bf54
# Slider over 1:1000 bound to `seed`, which seeds the random number generator in another cell.
md"
**Specify the seed value of random number generator**
1 $(@bind seed Slider(1:1000)) 1000
"
# ╔═╡ 05d7032c-6e25-11eb-0b48-7f9a594b5b88
begin
    # Seeded generator, so a given slider value always reproduces the same draw.
    rng = MersenneTwister(seed)
    # n×1 matrix of independent standard-normal samples (one per pivot point).
    u = randn(rng, Float64, n, 1)
    md"""**4) normal random vector zero mean, unit variance**"""
end
# ╔═╡ 4fdde52e-6e8b-11eb-3199-8112d143a31a
# Echo the currently selected seed value.
md"""
The seed you chose: $(seed)
"""
# ╔═╡ 431df5da-6eb4-11eb-11af-898fdf3601d5
# Parameter widgets for the selected kernel (row `ik` of the range tables).
# A positive upper bound in L_𝜎 means the kernel has a 𝜎 parameter, so both sliders are shown.
# Each slider runs from the lower to the upper bound in steps of L_d_𝛼[ik] / L_d_𝜎[ik].
if L_𝜎[ik,2] > 0
md"""**Choose parameters 𝛼:**
$(L_𝛼[ik,1]) $(@bind 𝛼 Slider(L_𝛼[ik,1]:L_d_𝛼[ik]:L_𝛼[ik,2])) $(L_𝛼[ik,2]) **and 𝜎:**
$(L_𝜎[ik,1]) $(@bind 𝜎 Slider(L_𝜎[ik,1]:L_d_𝜎[ik]:L_𝜎[ik,2])) $(L_𝜎[ik,2])
"""
else
# No 𝜎 slider for this kernel: 𝜎 is set directly from the table entry
# (a non-positive placeholder) so that the name is still defined for K(x, y, 𝛼, 𝜎).
𝜎 = L_𝜎[ik,2]
md"""**Choose parameter 𝛼:**
$(L_𝛼[ik,1]) $(@bind 𝛼 Slider(L_𝛼[ik,1]:L_d_𝛼[ik]:L_𝛼[ik,2])) $(L_𝛼[ik,2])
"""
end
# ╔═╡ 2506c99c-6e1d-11eb-31a2-dd1b99a80d8a
begin
    # Covariance matrix: C[i, j] = K(L_x[i], L_x[j], 𝛼, 𝜎).
    C = [K(xi, xj, 𝛼, 𝜎) for xi in L_x, xj in L_x];
    # Factor C by SVD and scale the left singular vectors by the square roots
    # of the singular values to obtain the mixing matrix A.
    F = svd(C);
    A = F.U * Diagonal(map(sqrt, F.S));
    # Transform the standard-normal vector u into the correlated sample z.
    z = A * u;
    md""" **5) sample from mvG** """
end
# ╔═╡ ac913758-6e27-11eb-1649-37b1be855f5b
# Echo the current parameter values; 𝜎 is shown only for kernels that have one
# (positive upper bound in L_𝜎).
L_𝜎[ik,2] > 0 ? md" 𝛼 = $(𝛼), 𝜎 = $(𝜎) " : md" 𝛼 = $(𝛼)"
# ╔═╡ b3e26166-6e1e-11eb-1708-25be9e2a339f
# Plot the sampled vector z against the pivot points L_x, titled with the selected process name.
# NOTE(review): `line =(1,1,:line)` relies on how Plots interprets the tuple entries — confirm the intended width/style.
# NOTE(review): both `legend = :none` and `label = false` are passed — presumably redundant; verify against Plots.
plot(
L_x,
z,
line =(1,1,:line),
grid = true,
title = txt,
palette = :tab10,
legend = :none,
label = false,
marker = :cross,
xlabel = "x",
ylabel = "y",
# Fixed vertical range, so samples larger than 20 in magnitude are cut off.
ylim = (-20, 20),
)
# ╔═╡ Cell order:
# ╟─53b7c092-6e9b-11eb-0fa7-57fc1385457c
# ╟─0ac101ba-6ecc-11eb-090e-6b5998c6bbfd
# ╟─c823a222-6e1c-11eb-18a8-01a1b673d7b0
# ╟─0bef3462-6e1d-11eb-08f6-e3b2c36a0721
# ╟─dca69d60-6e8c-11eb-3305-3fa2cedd5dac
# ╟─1edc6c32-6e1d-11eb-2f8a-b9dd2aff7b99
# ╟─05d7032c-6e25-11eb-0b48-7f9a594b5b88
# ╟─2506c99c-6e1d-11eb-31a2-dd1b99a80d8a
# ╟─21cff596-6ecc-11eb-093c-f97fac9ab7cb
# ╟─2837bf6e-6eb2-11eb-2b2a-190fa3669085
# ╟─bc5fc60a-6e92-11eb-0003-b58331ceddcf
# ╟─6be2bb7a-6eb6-11eb-2cfb-87383e68bf54
# ╟─4fdde52e-6e8b-11eb-3199-8112d143a31a
# ╟─431df5da-6eb4-11eb-11af-898fdf3601d5
# ╟─ac913758-6e27-11eb-1649-37b1be855f5b
# ╟─b3e26166-6e1e-11eb-1708-25be9e2a339f