Skip to content

Commit

Permalink
Merge pull request #55 from Saurabh7/master
Browse files Browse the repository at this point in the history
Some fixes for GP regression
  • Loading branch information
karlnapf committed Jul 14, 2014
2 parents 575f592 + 99be28e commit d914473
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
15 changes: 12 additions & 3 deletions demos/regression/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,10 @@ def _process(feat_train, labels, noise_level, scale, kernel, domain, learn, feat
n_dimensions = 1

likelihood = sg.GaussianLikelihood()
likelihood.set_sigma(noise_level)
if learn == 'ML2':
likelihood.set_sigma(1)
else:
likelihood.set_sigma(noise_level)
covar_parms = np.log([2])
hyperparams = {'covar': covar_parms, 'lik': np.log([1])}

Expand All @@ -194,11 +197,17 @@ def _process(feat_train, labels, noise_level, scale, kernel, domain, learn, feat
zmean = sg.ZeroMean()
if str(inf_select) == 'ExactInferenceMethod':
inf = sg.ExactInferenceMethod(SECF, feat_train, zmean, labels, likelihood)
inf.set_scale(scale)
if learn == 'ML2':
inf.set_scale(1)
else:
inf.set_scale(scale)
elif str(inf_select) == 'FITCInferenceMethod':
if feat_induc != None:
inf = sg.FITCInferenceMethod(SECF, feat_train, zmean, labels, likelihood, feat_induc)
inf.set_scale(scale)
if learn == 'ML2':
inf.set_scale(1)
else:
inf.set_scale(scale)
elif feat_induc == None:
raise ValueError("Argument Error")

Expand Down
12 changes: 10 additions & 2 deletions templates/regression/gaussian_process.html
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
.style("stroke", "lightgrey");
}

function TrainGP(data)
function plotMean(data)
{
json = $.parseJSON(data);

Expand Down Expand Up @@ -55,6 +55,12 @@
$("#TrainGP").attr('disabled', false);
}

function TrainGP(data)
{
svg.selectAll(".heatmap").remove();
plotMean(data)
}

function plot_predictive(raw_data)
{

Expand All @@ -77,6 +83,8 @@
var y_scale = result['y_scale'];
var color_scale = result['color_scale'];

svg.selectAll(".line").remove();

svg.selectAll(".heatmap").remove();

svg.selectAll(".heatmap")
Expand Down Expand Up @@ -107,7 +115,7 @@
.style("stroke", "grey");
$('#legend').html("<span id='lower' style='float:left; color:white;'>" + Math.floor(domain[0]) + "</span><span id='upper' style='float:right; color:white;'>" + Math.ceil(domain[1]) + "</span>") ;

TrainGP(raw_data)
plotMean(raw_data)

}

Expand Down

0 comments on commit d914473

Please sign in to comment.