-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_nn_cl.py
93 lines (65 loc) · 8.06 KB
/
main_nn_cl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# MIT License
# Copyright (c) 2024 Henrik Hose
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from time import time
from dynamicsgeneric import f_rod_generic_m as f
from utils import *
from scipy.integrate import odeint
class Simulator:
    """One-step simulator for a continuous-time system with a held input.

    ``f(x, u)`` returns the state derivative. ``run`` integrates it with
    ``scipy.integrate.odeint`` from ``t = 0`` to ``t = timestep``. The input
    passed to ``run`` is held constant over that interval.
    """

    def __init__(self, timestep, f):
        self.dt = timestep
        # odeint is asked for N + 1 time points on [0, dt]; only the last one is returned by run().
        self.N = 1
        # odeint calls func(y, t, *args). Time is ignored and the extra arg is the frozen input.
        self.f_pw_const_input = lambda y, t, u0: f(y, u0)

    def run(self, x0, u0):
        """Return the state reached after ``self.dt`` seconds, starting from ``x0`` under input ``u0``."""
        time_grid = np.linspace(0, self.dt, self.N + 1)
        trajectory = odeint(self.f_pw_const_input, x0, time_grid, args=(u0,))
        return trajectory[-1]
if __name__ == '__main__':
    # Replay a stored MPC solution on a plant simulated with mass m2, once with the
    # stored inputs unchanged ("wrong mass") and once with a first-order input
    # correction for the mass difference ("lin corr"). Both are plotted against the
    # stored prediction.
    N = 25           # number of control intervals in the stored solution
    dt = 160e-3      # sampling period used by the Simulator
    m1 = 0.02        # NOTE(review): presumably the mass the stored solution was computed for — confirm
    m2 = 0.04        # mass used by the simulated plant (f_true below)
    u_max = 9        # symmetric input bound: clips the corrected input and is passed to the plot

    x0_flat = np.array([-0.029150275832573513, -2.6582845139655964, -0.5333006949571231, -7.5397636927576634])
    U_flat = np.array([-3.1990419294028176, -0.794773068138816, 1.2531017316535817, -0.22106537511282393, -3.1747048283292076, -0.5857184075490766, 7.702761812426266, 2.2829876722361386, -1.1493895079346645, -1.0061765988246132, -0.4835169568239018, -0.18159892019485632, -0.0417152079345123, 0.017508910825344986, 0.03932434657693387, 0.0443058601839241, 0.04205987631319994, 0.037011734420273334, 0.031144341578844624, 0.025293193048218046, 0.01975798896624898, 0.014589296916778889, 0.00972540994894956, 0.0051137789255530095, 0.0013242484134987592])
    # Stored state trajectory: one row of 4 states per node, N + 1 nodes.
    X_flat = np.array([
        -0.029150275832573513, -2.6582845139655964, -0.5333006949571231, -7.5397636927576634,
        -0.11567812737956844, -3.8820692547321083, -0.3710871218320493, -6.876666307730988,
        -0.12962957960070415, -4.7316242960277135, 0.029775254595346058, -3.808039767179991,
        -0.10893270168684664, -5.112972156082754, 0.15721719777894766, -0.9890674833859736,
        -0.1026652742196383, -5.040252273358829, -0.02909383507857863, 1.8716022910006231,
        -0.1373518792686939, -4.496297506202736, -0.279619873884547, 4.88365576767026,
        -0.14001609605106005, -3.479704659952358, 0.153098437017297, 7.496003057076204,
        -0.038359802402083955, -2.159042589369665, 0.7243284454473713, 7.688780381103804,
        0.01885072832539487, -1.1899986120827508, 0.21606045148444042, 4.586951268866338,
        0.020771869626458304, -0.6187834407223295, -0.0866361996890321, 2.6527776023707665,
        0.005291698087658117, -0.3043719002631102, -0.10378441688506705, 1.41145438672515,
        -0.006078202464662078, -0.14468580277423576, -0.05601488889690508, 0.6926788778954206,
        -0.011243703105757558, -0.06795001313271198, -0.020925626307421424, 0.3268575742083951,
        -0.012649737166308978, -0.03200328585425957, -0.0029343821460961626, 0.15181291428994034,
        -0.012266114153331967, -0.015324807197089379, 0.00492621250886948, 0.07008603650834536,
        -0.011166655096152485, -0.007597850296244257, 0.007728654960763801, 0.03232519080785942,
        -0.009870279858225263, -0.003999040762286438, 0.008183332142653363, 0.014971816178419015,
        -0.00861045650662248, -0.0022979263488892847, 0.007621738790239718, 0.007018754335934126,
        -0.007482644902724832, -0.0014694464758818356, 0.006673741125538314, 0.0033730536598768133,
        -0.006518088525413136, -0.001044300815063289, 0.005626117171145289, 0.0016950118046194676,
        -0.005719699784808365, -0.0008074049831362683, 0.004600435351791135, 0.0009205583980546864,
        -0.005079296086279481, -0.0006581466650591143, 0.0036393588248536265, 0.0005789347116663735,
        -0.004585844403593736, -0.0005446649604291179, 0.0027477779321710997, 0.0004852832316331228,
        -0.004229420199639719, -0.0004305533987177648, 0.0019123024164723473, 0.0006136613722113719,
        -0.004002448832471721, -0.0002745236357741861, 0.0011172570797733906, 0.0010528184328218197,
        -0.0038919295054432234, -2.5453422734502465e-05, 0.0004257866070094434, 0.0018798384458440195,
    ])
    # Stored sensitivities: one row of 5 entries per stage, N stages.
    # NOTE(review): entry 0 of each row is used below as d(u_i)/d(mass); the meaning of
    # entries 1-4 is not visible here — confirm against the code that exported them.
    J_flat = np.array([
        110.42603032516227, 0.28468080111658445, 1.0215166651082577, 2.547775855859885, -186.83094321289474,
        62.82169639468288, 0.19305889133863585, 0.6377428446848278, 1.2223859331939986, -73.87296444407477,
        -30.405503565954287, -0.10863050901628601, 0.20765997438780387, -0.8135831969627296, 78.72701799782392,
        -6.478083456644055, -0.10518873517270577, -0.23915990904970102, 0.0040907920278002155, -3.7747767594061563,
        11.098480940571067, 0.2282962566561193, -0.07950766521104492, 2.4775410309351993, -169.80432395134957,
        -211.93425013532357, 0.3299403235495759, -0.3581286273768403, 1.6714600302348561, 12.625620342269185,
        -65.25634247218203, -0.48013131535236475, -0.5044847206504778, -4.971110094204886, 336.9912928810636,
        200.2881566290515, -0.3887572083870632, 1.0763866741417785, -3.053345425559211, 30.183870387016647,
        18.723297378403284, -0.04844918463633108, 0.028457183563057604, 0.12262270370770728, -34.580092899248555,
        -30.972857395878233, 0.03458837663959513, 0.07848969503742477, 0.5286713266632036, -16.299677142477254,
        -24.479596147686827, 0.026919554335237128, 0.12443136808840574, 0.2899624294714738, -3.3675588868573354,
        -15.653700725240354, 0.009526822159472564, 0.09979665212849934, 0.09578494748245263, 2.1485037857681895,
        -9.503130691939155, -0.0016289587091675717, 0.06897028665958164, -0.003609490284257517, 3.985761490300411,
        -6.037248843855839, -0.007057729725804741, 0.048649147322906094, -0.046826267139315404, 4.310259422459482,
        -4.049161494199547, -0.00910402599348567, 0.03680120336444632, -0.06195287356641376, 4.040723985345001,
        -2.8596908858637557, -0.009393450887257258, 0.029570654499966605, -0.06373186060748422, 3.561515829072784,
        -2.1071430931650705, -0.008833025239125693, 0.02455789295301284, -0.05944003931788378, 3.035884647254655,
        -1.5990276534754915, -0.007894229630679192, 0.02055781608795861, -0.052602631599564395, 2.530354150343859,
        -1.2319637564978443, -0.006813040691846246, 0.01700879375560665, -0.044891139033266275, 2.0679999443464876,
        -0.9494197942655285, -0.005702598467926888, 0.013655233100745536, -0.037062983002739025, 1.652531246055341,
        -0.7194297695234865, -0.0046127228088715805, 0.01037922900507249, -0.029425198799507692, 1.2792631096835267,
        -0.5228244857630977, -0.00356034064803408, 0.007127117993768006, -0.022064548148400007, 0.9399652991592927,
        -0.347086112828283, -0.002543485734719613, 0.0039098263296339685, -0.014975519834762889, 0.6251328307935021,
        -0.18475193996514397, -0.0015455088748743814, 0.0009654882272010312, -0.008221872465404711, 0.3286282300565788,
        -0.05035074069337813, -0.0005920174883263588, -0.0008519816319100466, -0.0025278627280065787, 0.08551173571213128,
    ])

    x0 = x0_flat
    U = U_flat.reshape((N, 1)).T        # (1, N): one input per column
    X = X_flat.reshape((N + 1, 4)).T    # (4, N + 1): one state per column
    J = J_flat.reshape((N, 5)).T        # (5, N): one stage per column

    def f_true(x, u):
        # Plant dynamics evaluated with mass m2.
        return f(x, u, m2)

    sim = Simulator(dt, f_true)
    N_sim = N

    X_cl_wrong_mass = np.zeros((4, N_sim + 1))
    U_cl_wrong_mass = np.zeros((1, N_sim))
    X_cl_lin_corr_mass = np.zeros((4, N_sim + 1))
    U_cl_lin_corr_mass = np.zeros((1, N_sim))
    X_cl_wrong_mass[:, 0] = x0
    X_cl_lin_corr_mass[:, 0] = x0

    for i in range(N_sim):
        # Stored input applied unchanged (slice assignment copies, no np.array() needed).
        U_cl_wrong_mass[:, i] = U[:, i]
        X_cl_wrong_mass[:, i + 1] = sim.run(X_cl_wrong_mass[:, i], U_cl_wrong_mass[:, i])
        # First-order correction for the mass difference, saturated to the input bound.
        U_cl_lin_corr_mass[:, i] = np.clip(U[:, i] + J[0, i] * (m2 - m1), -u_max, u_max)
        X_cl_lin_corr_mass[:, i + 1] = sim.run(X_cl_lin_corr_mass[:, i], U_cl_lin_corr_mass[:, i])

    labels = ["wrong mass", "lin corr", "MPC prediction"]
    # Separate names so U / X keep referring to the stored arrays.
    U_plot = [U_cl_wrong_mass, U_cl_lin_corr_mass, U]
    X_plot = [X_cl_wrong_mass, X_cl_lin_corr_mass, X]
    plot_pendulum(np.linspace(0, dt * N_sim, N_sim + 1), u_max, U_plot, X_plot, labels,
                  latexify=False)