From 50e63d8234e8465eb37e057d08766eaf98a6e79d Mon Sep 17 00:00:00 2001 From: Johannes Schwab PI Sjors Scheres added 22022021 Date: Tue, 31 Oct 2023 15:17:00 +0000 Subject: [PATCH] continuation --- .../optimize_deformations.cpython-310.pyc | Bin 23209 -> 23455 bytes .../deformations/optimize_deformations.py | 75 +++++++++++------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc b/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc index 5d60935056228480e4a0a3fbc0e1b31024afa1a5..7817e2225ee96a7f5732f408885bdfa240bcd4be 100644 GIT binary patch delta 7255 zcma)Bdw5humcO?`aLNXn zz$mt9xgZZ==pIEtMZlJo$L@?cJLB$*qYk<<8=Y@vcXW0f*U|Btb^qDH{Z-xDO^4tp z`MU0@Q>V_UbLyN^r;5`T)Z?$JeOt4$`zZM5%$(WWZd1?qeIO%!M7g4b^E=vNjTx`o zY?Ko#%W~TSc1_i1lY>UkXnHW1vMhx!<&E6M-FmursJEQvMee2%QI(~#{E&@Yp32+e z-Elm-S5J>r(StNUYBOgRC~T}XcaAw16c5HJ0zH^ln*@J{iQ97cT13?}kBvK;#h2-G*{kamTK25^glbK337qkH zPO&{j8)szpEf0L!;UO2f*Vtw&%jfbHw8A@&7P9dnRaNM&NIknxe~4V(N~SQS)5aWY zY{g1a-zoaiPN4BlV~%147M{Sq7fN9hqjrnOB=B(G!o%v{3#K=j0C_tmVf&^l>hI3jTsvzAs~}61Mxl+6;4uzefBhEi8=J zl{?&?;x9C=HjcTd1YOIQK0>!95*?{LhHW4so| z<#n?ev;>YS=voqV$=LPSu4Qm)m2XHuK7BcStAIwZnISuGrB=Q%LlJ+>Nuyb5Kcz(B zn~d!ILif*^gXgdZxQ<^Ne(m^eX0!Mf@LCa7x|-wKw%Ed{T7l4%Mi(eFra<9ag}#Yi z2i5a!_z`A$@LOr{{6S5#u~Qr}9+6gR@~*P%Bq86Ou<&Zj!sgc%VOC#$Vev}HY>upn zDy&NEATVXzIe6g#XxOf!QjLa1b!jql()X(y|OuLUVc(`Y;Xf~H>=A&4pzqpM%u$QH@zVSOFV z($l;SRui&)1BXcl@8?{&kETVu@RMmYNq#5w^3zoR)nu)($|`6vO0IhsbI>lDR-v#` zd?!_iplvZ`7L;eMCsk?yos3NdMfW?jDc@#gN`s~B;WTX$by7}Cih%SD#?^u`g}z8D z*0>KAHzC7Uq;JH?s2t{m$SD*?dl#!UrWY0lYNKk@2GpQ`TYnEu@lgHzw_t&>ld00BNj_!Ue)+ehNp` z5+pf(#KJ#m;*CN53j+>qGO~k1OV@RHq0jX75|J?5$jRnXC!6< zF19$4*OT%;TWoxAM{#+&Hie}MYam;uiY)7C@TK2+&KSkTLjoO<4#ekk`Vg7nkQ*ql z)&_4rc@x8&5_1aXSO%Uo7(Fdvw}hVuyp2u+`+|g?2lS%UK!FBFB0l{k=SYd{Cjg95YX*nBf_Bv`U1B(6cVKV&@V{A!D;WlcGs&tYxno9ZB zdR<0S$)E{tDiWN{{egMscpFR}rm1C5n!G{r%}7!7`QePMZyFa$%GG7Y2PKWE8UC44 zMt`Hev|pYpRIn_SpXBHGdHynw=>fxEIwa6XtDs!0+#VLtIkp^8sZ)Cq&Y==veSxh> zjEg$yrtgTcPO3tN6q=(atGaS}C8>ljWh+ErkriT)*0ge{z@ zRa32GCZL|efPp)UV;zMh4#4eECDLllFDt)aqcR#JQhre7m*^5JO)Mygo^ztYFJWdC zbD~(nI9-aS^2@-q7-!4Mig(Zz@yAxOrclZjGeMMUEh+pW zU6jS{MSZ$=H{HpaVGwcjWSB<}47es_>*~)J(1OSwx)>P=Z6P%>DhfSIVSq8T!1MiH z2`|$>-uT~vWn(k_vyu~^&&CvfG48#+47(X#QaNa3zqO%4=^-;DDpM7N#KWq_; zyN)%0Jo>1WN98xpSbq%Rw;BOB3jsJDA%!4DAPU#cXAg(6apOD5f69Nx-_Rd3z8+N4 z5Xb!s|5yHV8YhklsVTJNzd*vm(8GA?1+gE^h21vNXq(<|HhKl1Zf}k zlJ;Q_X+wq<7}-Pfq{JaZoz5I>2zE4wY4X?=y+JG$`56$KO&hV|*~Z$TWmEm)z(P)J zg`Ozvq~zc7-%+cz*5AeD-dr=B3zt?%Uuj$&T2$Za`l*{xY;O zu&THDBdo_I^02g*VOotw2krd#`eAt}8s^a`aUUT&9HHOSN3a`eN&Ye9KaFeFXy&{_ zVQo=6|3fd#e^D&J+VLWR`jlUx7~5>Ye>9IA{t4a=aP;s$=||}kK_${M{V4&B3EHvg zpYs2vE0JSS1;)GV%QUwVhm4O+#A4*ufqN%&T1Qo`_J1h7N)a|WTHpo-wpar>+TX;U z$brAf_nxi#50FuJ;SHyQZDSNlrOw1V4wqw-keP6W#KLc(@n!eo@JJ{nG5{g zA7{IP*&{J~1SW8tec#+eX+A6aCvkYx+|6$oeT?2uqTNf#N0udyv&T_P-Ry(%Zth{zB;wYB(2{d-48v;M?E-NL=Lnv|C(^#rbN!|50a- zFqdy|oQ;t=d`RFnWYVZ{@v9QzJLSBmzWM1e{&Gm-hVCca)q9Fhk}kLy;tA7uVR#AF za9dg_?;t2+?2xc{;>ptvm+Fagt^tl6W{br`XOi`gB&?qvF%#^wH+M3=V-+D>xa2lk z9vL4H4?s!Gp4Ly$dp$5m^%KU95hZ~ApA_49{&H%dEr@LabyBg& zOW5OltsVwLr5jeeM5$-iy4bgP@)3K81eET54nI+qoqNXpfT%ZG_xtqx3e~AL3o}R8#@<3|=-(=vfFep%ch+CiEP=Z9yk#wFR9bWc7wO^1DdY6WlE9gxsV7wa=Y+*x31=f* zGlx*A$-4OLG4T%7V-ycI=d985X4T9;*Pn+aD(Q1PmX4v%*vsce>WR#B8 zJ;q-r-j!*BqJJt@KIvI|T8DtQ8E;RSp}uJ3R$Ns(V%I7bsDp!oVM?p7+23t%+OV-Z zRcrDwUw0bd_9ji+Vho-(E9RSaSXH+f=`+d;y8CoAYps6I`gJYM-D#Vfv<+<=jfFD~ zradAAQ0$8tV{9rJgJ<1miHx~s4O7dX#HcH`Z6j;an>}l^=1pzQzSUc`92Jw~%$*=G zuFfttUYpG-z7o8)py{@6&{n%rs@q%FYdb{WR=2>Fy0O(?x1c+9%{pJxdhI1adsZ~( zM04I)Sverv5VWbyn%1gmUmJTWhoxN`Hv&E7{%|2oZPw*3Wc&zEg7 zoUZ2EW085a=_P(>IecEX!{cGCt34j=CxYisMN?v=H4G~JNPyjkC0BVjZ}biNOyAk5+Izig(jeA zL&KFf!EOhMs3_Png3P+0&dl!4FpfH-V{pF_b!K!N8F%-a-9KjcoBdVYn@)$|r{wFp zr%s)!bLzb6RPoH~>gn@prni6pGzI^z{^#`i;(sXTG9O4ySwBKiLIc`zqt&Tz*=#I1 zT9jU6E4681eHJCtXc|Wk1`=u$_$pq_oxCQbdP?=#z|LikfIWN%jgF{%l|I{O%q%XM zL(6F%Eo65cPv^DzT=wGzg=$Z#dsK_gDR9QQ%zV3DJ8E3bEb;GdbCHv3*4kz%s~7Nx z=ppxmw2a*yR8@s;56@#|`h0S_E11HRb_loDmaoM19ip$NgewKkJ(VTrD-o4VVD|(Q z*u4?EwZ=rOQFGH8R{t(6+Or6dw^}4*Yxysh$XDxytlWrX4P3DxWZ$wwRk%x4+D7Xs z?woLLL{Sxt-$$;14N!`EgnMCl5w>!FzQQI&Y-b%?)e{3t18*WNRhWRJ4jeXkijyJX%HdYEcl zOUM~sD&(LV^9y;V1cV~*(9_&m?tYDS%>06t3T3&L;fwvcmeH#4@`%Fgs5-pd_&U2J zX9WzB$)*Nvyk1{HH3BxO^6&Joq?Ife*aqxwqQfJeffpPXMty3_>qXhW(x}fGY5QQ7~{P= zbLta(pZ2ooD)$=hr(~u;>O`$x=tVE0`vWSpG}Y=&mc4Sw_82f?uQhV@xx&IaqZt7e zAyiOn2cHa3fXa<&xsL2oni^gUZ_0P$;lc2aJ;nOUJ2>ON_lp$WZFSfCo8NF-}1>K7}pv0fh9HTNtKdGSwrgHYnEmD z!51^v+FF%RlGgQo`lKZ2#15T&MD3K(m>?j%!T9;Wv4f#zceoLv?uU9!$nfOp8!&jE zI^R$S4fM~Cs1X}nC&(NSYZC22dpqmKK5b$5_QD>p7$s>FfN!FwV1k1Z3IaMLp{D^I zme2t}M?6DsO<9TQH<3!hotU&!Lc0L%0_0&86v`F4^HM0H79ff7qZWRji8roeLrL1l zi2fw4Q*RdAH+{K4B*YeSuu5uY3yvkQg&np$L<}88zhaQW76lXGAJVv&IU+WO z6cJlj2hZ;wuW{+d^}#u7n&ZQVCWJT2lyQ&sZD+}F>||{pQpFM(GKFj@46%%Sfdm=i z5-a`E2D;CXly8}!^H;LvvE&<uhkz&K?YNNX+xt|0=Ax38SYaTqEHZ0B@$#z`iJ<7XZB^cc6q) zv`O+n>PZB^2+3Sqgj}vI{yt*fFT}(!PAo~}bhlZXFuzW>nW@FLx*?}~t3`uvABIhR zc=V{Excgv4%3+Q?rBTTIlzE6EL8oZDjDmvvNksS4Y{KrTjP52cO<@I;M0+)F7shm30=)y8gFxpng&cw?aQ9DLY3l~o}lOpLaAFC zRkKfHEy;i&AJSnH1(_Y@WTQ?+Z9g?-zx?`w3H z)uIN}sWlIWPJytFz|N12i(;q8x5QXG%|RM8^>$5`W+n6*OJoCrHn>#<;%>DNiFDp- zu8#5|nV*NQ9}@GX3KjeQCRx}*Y1)@mAr})+Z)U(ir{Zk>1D11!{ShVHXygnnxz9sI zG+yNOfXXk@MOJ7Dl&0N=o@1iIFJfjTbD-kEI9-e+@;884XS570%72V5i9hCI4+j%! z?ffVFpZJ>?-DY$SotIrpZ$#80MbL|?)4knLhZXq)NMvtVi3w>VHP06)us&$pJBTl( zrQw})Av^}V1=sM{2=-YB1&pVqu3de~*)-q##>HVpd1=1S;`5@%#uR=*UvGRfEYCg& z{*XL;OwQWiKthmJAag(>CYaN!`_S4>ub5iv5M`>(aa(NZ8nHNaqzxJec zJ0*glpLSKGbN*|Kx*D~*NJVd3)bG(187doTlL!^RjFo=cx6*yZ^l#&t`u-=jv@Mjb zrCV|Go)srGGEUaZIN8WInMIK@g~z6AvuHkL->2~Tn6VsT4;rwa-+PzBb1~9tZBjtis(ysA z=)NzPeP4Vfsh29`vK5$OF58M_O|is`j@VvI17<<>0$U}MK)SY%+W80c0h=&H;hTXK z!CZx$(g4^^n^|K3D23+H5_W{`+q%L+ZYg%Vxo}kTwysSBB zM{Hcv4`KFV$?FsT?-sAs;Py!$ZlCnxc4+6!+Rd{^UicDbfpaZ+MR_MGZ}Yy1quNX7 zn1;L2b$UyFKSSS#eiQoQ>!ItWt&n-I(`Uv{OUnFeUyIyyKRhh;x!=?(jSJZM=lWrJ zZi;IRPECpX0$Jn;eNJCMuP8bBA0XpV+_c6q$6$r&5j+25AIzU9>Sr7A*nl?>eu<*2 z)q=k?40hXm5BR?xdBxPvGaP4f9HIShl7j}~a$K_MbX1xQMcL$zlwy-v~HJC{4 zu_qdmW92y5af^iDH_>?d_gR4l6JqmEutyO+{v=Gl!u-b8JBs?BV2@ep21yVYM;WRH z7!22B7+EGp^d1mBw*9y*Zm9s~?T{<&uvTgZW~apL6d3;rw#zi2)SY!i#I5kSY0B>y zeS$uWquouo+FEKj!Jb4J*3)~6&cyYCNPlw74_%yKdyuwzNcYk&<9Z?bRjfCU?bE|% z4CoQKh2;=B4hO)E~ za^8XOeiH0{5hQUl_Y&^tz1hcc7poZJY0_vKT>u$wNh|ps1Z9*R5*m*^W!m9Vy>X7U zz_G(@xp=^g)Bcf|_EW-U_WevuCvLH$`bqjtFU&DuersaH z^OQf!cPhRz)8&1TP7PMLFD}5deD6>dZAO3#sGagfa)Pb1wHr_zieNnYh$_%5gJHk! z3?yucER^DY4nI+_o$~veT?Fn+CeTHh#hUs0j5xiS= zMifBL;^oqWegtMFbP~zVgr1{!E$9@jv7k<37W6zdSkP&5TTtc;w9W!wr1jxWqj7A3 z|E1%J&#C-|wc~8x++I6gN+Z}A9Ma-S^*r4m9H4N~3>pD_iIVx{4)G4vYZMPQuUMmJ z&4QZ$RX+_)RM21X7&@N5Vdvz-jV!+X>{W}nclBLlRo^n9gZXYVwM(Ck)#5P1txNH2 z-^pHvHtYQ8EjZ1z?d@hIPYlUQLO+<*$D-=@XZybD#nQU#3T?uZvbFycC7r0h@XL}l zRzhfimTl=Bkp2Rx>@^xH9#>UG+%A?Wp~RMU6Oz}>0b9ycg@233`D|P_zZEsBoOPat z8f(TCsjo-3jl0Kg{G@cZIwv~k?nYJJ8hy3wLtD}T{MU$f(AabD*JBO|@UUnWi{^-E z%wa%XDO&wnS5uv@ag@JEE#>#+TDeZr3&obXr@z=JmNY zG&I(CC2gwHnl?6f+3T8`yOPkf)@erZj6+FWtiCn+=8W++b*ph><^+phbmXj&sv0tu z%pR1Hrh-NiZUeQkcXqz9XEvKWM4%-ndy}@tnJ}-_yFq(QaJE%0btX1@eU(eQ64y3( z>NaR+1e=#d^NMKB8sp{+>AzplCe~}3SJVD(te7)0>2)!9!3Y4nD8M(2b8~*I8pis$ zAE>*dlOOzVoBB+&tYVO=o{P?}oUbNb7F3swqYFIhOVQMYb!t-9KL#WWP>tU^2S?8? zDorV9fs~^bb|t%9%)7?r(%up?-Vsfn@xN8W27WHUk45t{(fo&Ku14=#HP%)#Tx{?| ZAi5GB@nUacDebZ#=xLgx+nhty{|8$KEJ^?X diff --git a/dynamight/deformations/optimize_deformations.py b/dynamight/deformations/optimize_deformations.py index b4d7fc5..ab93f49 100644 --- a/dynamight/deformations/optimize_deformations.py +++ b/dynamight/deformations/optimize_deformations.py @@ -176,7 +176,7 @@ def optimize_deformations( particle_dataset, val_indices) lambda_regularization_half1 = cp['regularization_parameter_h1'] lambda_regularization_half1 = cp['regularization_parameter_h2'] - n_warmup_epochs = 0 + n_warmup_epochs = n_warmup_epochs half1_indices = inds_half1 half2_indices = inds_half2 print('continuing training from a given checkpoint file') @@ -283,6 +283,25 @@ def optimize_deformations( encoder_half1, encoder_half2, decoder_half1, decoder_half2 = load_models( checkpoint_file, device, box_size, n_classes ) + decoder_half1.model_positions = torch.nn.Parameter( + initial_points.to(device), requires_grad=True) + decoder_half2.model_positions = torch.nn.Parameter( + initial_points.to(device), requires_grad=True) + decoder_half1.amp = torch.nn.Parameter( + 50 * torch.ones(n_classes, n_points).to(device), requires_grad=False + ) + decoder_half1.ampvar = torch.nn.Parameter( + torch.randn(n_classes, n_points).to(device), requires_grad=True + ) + decoder_half2.amp = torch.nn.Parameter( + 50 * torch.ones(n_classes, n_points).to(device), requires_grad=False + ) + decoder_half2.ampvar = torch.nn.Parameter( + torch.randn(n_classes, n_points).to(device), requires_grad=True + ) + decoder_half1.n_points = n_points + decoder_half2.n_points = n_points + else: encoder_half1 = HetEncoder(box_size, latent_dim, 1).to(device) encoder_half2 = HetEncoder(box_size, latent_dim, 1).to(device) @@ -305,33 +324,33 @@ def optimize_deformations( decoder_half1 = DisplacementDecoder(**decoder_kwargs).to(device) decoder_half2 = DisplacementDecoder(**decoder_kwargs).to(device) - if initialization_mode == ConsensusInitializationMode.MAP: - with mrcfile.open(initial_model) as mrc: - Ivol = torch.tensor(mrc.data) - fits = False - while fits == False: - try: - for decoder in (decoder_half1, decoder_half2): - decoder.initialize_physical_parameters( - reference_volume=Ivol) - summ.add_figure("Data/cons_points_z_half1", - tensor_scatter(decoder_half1.model_positions[:, 0], - decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1) - summ.add_figure("Data/cons_points_z_half2", - tensor_scatter(decoder_half2.model_positions[:, 0], - decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1) - fits = True - print('consensus gaussian models initialized') - torch.cuda.empty_cache() - except Exception as error: - torch.cuda.empty_cache() - print( - 'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error) - Ivol = torch.nn.functional.avg_pool3d( - Ivol[None, None], (2, 2, 2)) - Ivol = Ivol[0, 0] - decoder_half1.vol_box = decoder_half1.vol_box//2 - decoder_half2.vol_box = decoder_half2.vol_box//2 + if initialization_mode == ConsensusInitializationMode.MAP: + with mrcfile.open(initial_model) as mrc: + Ivol = torch.tensor(mrc.data) + fits = False + while fits == False: + try: + for decoder in (decoder_half1, decoder_half2): + decoder.initialize_physical_parameters( + reference_volume=Ivol) + summ.add_figure("Data/cons_points_z_half1", + tensor_scatter(decoder_half1.model_positions[:, 0], + decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1) + summ.add_figure("Data/cons_points_z_half2", + tensor_scatter(decoder_half2.model_positions[:, 0], + decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1) + fits = True + print('consensus gaussian models initialized') + torch.cuda.empty_cache() + except Exception as error: + torch.cuda.empty_cache() + print( + 'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error) + Ivol = torch.nn.functional.avg_pool3d( + Ivol[None, None], (2, 2, 2)) + Ivol = Ivol[0, 0] + decoder_half1.vol_box = decoder_half1.vol_box//2 + decoder_half2.vol_box = decoder_half2.vol_box//2 if mask_file: with mrcfile.open(mask_file) as mrc: