#=
Implementation of black-box stochastic variational inference (SVI), based on work by Duvenaud and Ranganath. This implementation is a port of the Python implementation shown in the Autograd example.
=#
module BlackboxSVI
using Distributions
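#=
Black-box stochastic variational inference with a diagonal Gaussian variational family.
Given the log posterior, this function returns a closure that evaluates the variational objective (the negative ELBO).
Input:
logprob = log posterior; it is called on an mc_samples x D matrix of samples (one sample per row).
D = number of model parameters (the variational parameter vector therefore has length 2*D).
mc_samples = number of Monte Carlo samples used to estimate the expected log posterior.
Output:
variational_objective = function returning a Monte Carlo estimate of the negative ELBO, the gradient of which can be used for gradient descent optimisation.
=#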
function black_box_variational_inference(logprob, D, mc_samples)
    #=
    Entropy of a D-dimensional Gaussian with diagonal covariance.
    Input: vector of D log standard deviations.
    Output: the entropy as a single scalar.
    =#
    function gaussian_entropy(log_std)
        return 0.5 * D * (1.0 + log(2π)) + sum(log_std)
    end
    #=
    Evaluates a Monte Carlo estimate of the negative Evidence Lower Bound (ELBO), so that minimising it maximises the ELBO.
    Input: vector of length 2*D containing the means followed by the log standard deviations of each weight in the neural network.
    Output: a single negative ELBO value.
    =#
    function variational_objective(params)
        mean_vals, log_std = unpack_params(params)
        # Reparameterisation trick: scale and shift standard normal draws,
        # giving an mc_samples x D matrix with one sample per row.
        samples = randn(mc_samples, D) .* exp.(log_std)' .+ mean_vals'
        lower_bound = gaussian_entropy(log_std) + mean(logprob(samples))
        return -lower_bound
    end
    return variational_objective
end
#=
This function separates the means and log standard deviations in a parameter vector of length 2*D.
It simply splits the input vector in two, treating the first half as means and the second half as log standard deviations.
Input: vector of 2*D floats.
Output: vector of D means, vector of D log standard deviations.
=#
function unpack_params(parameters)
    # An odd length cannot be split into two equal halves.
    iseven(length(parameters)) || throw(ArgumentError("expected an even number of parameters"))
    half = length(parameters) ÷ 2
    mean_vals = parameters[1:half]
    log_std = parameters[half+1:end]
    return mean_vals, log_std
end
export black_box_variational_inference
export unpack_params
end
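
#=
Usage sketch (illustrative only, not part of the original port). It assumes this file is saved as
BlackboxSVI.jl and that Distributions is installed. `target_logprob` is a hypothetical stand-in
for a model's log posterior: the unnormalised log density of a standard normal, returning one
value per row of the sample matrix.

    include("BlackboxSVI.jl")
    using .BlackboxSVI

    D = 2
    target_logprob(samples) = vec(-0.5 .* sum(abs2, samples; dims=2))

    objective = black_box_variational_inference(target_logprob, D, 100)
    params = zeros(2 * D)      # D means followed by D log standard deviations
    loss = objective(params)   # noisy Monte Carlo estimate of the negative ELBO

The returned value changes between calls because fresh samples are drawn each time. To optimise
it, pass `objective` to an automatic differentiation package of your choice; which one to use is
left open here.
=#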