forked from nickjhughes/feature-map-stats
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lattice_gradient_descent.m
89 lines (86 loc) · 2.45 KB
/
lattice_gradient_descent.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
function [x, fval] = lattice_gradient_descent(f, x0, max_iters)
%LATTICE_GRADIENT_DESCENT Gradient descent minimisation on a lattice.
%
% [x, fval] = lattice_gradient_descent(f, x0, max_iters)
%
% Starting at the integer point x0, repeatedly step to the axis-aligned
% neighbour (one lattice unit away) that gives the largest decrease in f,
% and stop as soon as no neighbour is strictly lower than the current
% point. Returns that point x and its function value fval.
%
% Inputs:
%   f         - function handle. It is called once with x0 exactly as
%               given, and afterwards with a scalar (1D) or a 1-by-2 row
%               vector (2D).
%   x0        - starting point with 1 or 2 integer-valued elements.
%   max_iters - maximum number of steps allowed (default 100; also used
%               when max_iters is passed as []).
%
% Outputs:
%   x    - the point at which the search stopped. Equal to x0 (same shape)
%          if no step was taken, otherwise a scalar or 1-by-2 row vector.
%   fval - f(x).
%
% Errors:
%   - x0 is empty or has more than 2 elements.
%   - x0 contains non-integer values.
%   - max_iters steps were taken without reaching a point that has no
%     strictly lower neighbour.
%
% See also:
% align_images

% Input defaults and validation
d = length(x0);
% d == 0 (empty x0) is rejected here as well; otherwise no branch would
% ever assign the outputs.
if d < 1 || d > 2
    error('Only 1 and 2 dimensional functions are supported.');
end
if any(round(x0) ~= x0)
    error('All function inputs must be integers.');
end
if nargin < 3 || isempty(max_iters)
    max_iters = 100;
end

% Neighbour offsets, one per row. The row order fixes which direction wins
% when several neighbours give the same decrease (max returns the first).
if d == 1
    offsets = [-1; 1];                      % left, right
else
    offsets = [-1 0; 1 0; 0 1; 0 -1];       % left, right, up, down
end
n_dirs = size(offsets, 1);

x = x0;
fval = f(x);
iters = 0;
converged = false;
while iters < max_iters
    % Neighbours are always built as row vectors, whatever the shape of x0.
    here = reshape(x, 1, []);

    % A cell array keeps each value's class exactly as f returned it.
    vals = cell(1, n_dirs);
    for k = 1:n_dirs
        vals{k} = f(here + offsets(k, :));
    end

    grads = fval - [vals{:}];
    [best_grad, best_dir] = max(grads);

    % Stop unless some neighbour is strictly lower. Treating an equal
    % neighbour as an improvement makes the search wander back and forth
    % across a flat minimum until the iteration limit is hit. Written as
    % ~(... > 0) so that a NaN gradient also stops the search.
    if ~(best_grad > 0)
        converged = true;
        break;
    end

    x = here + offsets(best_dir, :);
    fval = vals{best_dir};
    iters = iters + 1;
end

if ~converged
    error('Maximum number of iterations (%d) reached.', max_iters);
end