Amsgrad.py
# This code implements the AMSGrad optimization algorithm for a user-supplied cost function f(x).
from sympy import Symbol, diff, lambdify, sympify
def Amsgrad(cost_function, f):
    x = Symbol('x')
    print("f(x) = ", cost_function)
    f_dash = diff(cost_function, x)  # Symbolic derivative of the cost function
    print("df(x)/dx = ", f_dash)
    x0 = float(input("\n---> Enter initial approximation: "))
    errorTolerance = float(input("---> Enter error tolerance: "))
    learningRate = float(input("---> Enter learning rate: "))
    print("\n---------------------------------------------------------------")
    print(" *** Starting AMSGrad")
    print(" ---> x0 = ", x0)
    print(" ---> f(x0) = ", f(x0))
    # ------------------------------------------------------------------
    f_dash_num = lambdify(x, f_dash, "numpy")  # Lambdify the derivative once, outside the loop
    iterationCount = 0
    xk = x0
    x_prev = 0.0
    mk = 0.0      # First moment estimate
    vk = 0.0      # Second moment estimate
    vc_k = 0.0    # Running maximum of the second moment (the AMSGrad modification)
    b1 = 0.9      # Exponential decay rate for the first moment
    b2 = 0.999    # Exponential decay rate for the second moment
    epsilon = 10 ** -8
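    # The loop below applies the AMSGrad update rule (Reddi et al.,
    # "On the Convergence of Adam and Beyond"), which the code realizes as:
    #   m_k  = b1 * m_{k-1} + (1 - b1) * g_k
    #   v_k  = b2 * v_{k-1} + (1 - b2) * g_k**2
    #   vc_k = max(vc_{k-1}, v_k)
    #   x_k  = x_{k-1} - learningRate * m_k / (sqrt(vc_k) + epsilon)
    # Unlike Adam, the denominator uses the running maximum vc_k, so the
    # effective step size is non-increasing.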
    while True:
        iterationCount += 1
        x_prev = x0  # Approximation from two steps back
        x0 = xk      # Previous approximation
        gt = f_dash_num(xk)  # Evaluate the derivative of f at the current point
        mk = b1 * mk + (1 - b1) * gt         # Update the first moment
        vk = b2 * vk + (1 - b2) * (gt ** 2)  # Update the second moment
        vc_k = max(vc_k, vk)                 # Keep the largest second moment seen so far
        xk = xk - (learningRate / (vc_k ** 0.5 + epsilon)) * mk
        # Stop when successive approximations are within tolerance
        if abs(xk - x0) < errorTolerance or abs(xk - x_prev) < 0.1 * errorTolerance:
            break
    # ------------------------------------------------------------------
    print(" *** Number of Iterations = ", iterationCount)
    print(" ---> Minimum is at = ", xk)
    print(" ---> Minimum value of Cost Function = ", f(xk))
    print("---------------------------------------------------------------\n")
# Code execution section
def main():
    x = Symbol('x')
    cost_function = input("---> Enter cost function f(x): ").strip()
    c_f = sympify(cost_function)   # Parse the input string into a SymPy expression
    f = lambdify(x, c_f, "numpy")  # Numeric version of the cost function
    Amsgrad(c_f, f)

if __name__ == "__main__":
    main()
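# Example session (hypothetical inputs, for illustration only):
#   ---> Enter cost function f(x): x**2 - 4*x + 4
#   ---> Enter initial approximation: 0
#   ---> Enter error tolerance: 1e-6
#   ---> Enter learning rate: 0.1
# For this convex quadratic the true minimum is at x = 2, and the iterates
# should converge toward it; the exact iteration count depends on the
# chosen learning rate and tolerance.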