From 157af7141c640b21304a08475dbd9b5f46045652 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Dec 2024 21:53:44 -0800 Subject: [PATCH 01/13] Update tests_bwd.py --- dev/modal/tests_bwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/modal/tests_bwd.py b/dev/modal/tests_bwd.py index e08e735de..231c5b4d7 100644 --- a/dev/modal/tests_bwd.py +++ b/dev/modal/tests_bwd.py @@ -8,7 +8,7 @@ image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv") -app = modal.App("liger_tests", image=image) +app = modal.App("liger_tests_bwd", image=image) # mount: add local files to the remote container repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH) From 4cd438184373c477a74113964d01458d2188e919 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Dec 2024 21:56:31 -0800 Subject: [PATCH 02/13] Update README.md --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 64d89708f..7e0ea15a9 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,9 @@
Latest News 🔥 - - - [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training) + + - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)! + - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training) - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision! - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989 - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks! @@ -72,7 +73,7 @@ **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training. -We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. 
We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. +We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out our [deep dive thread](https://x.com/hsu_byron/status/1866577403918917655) ## Supercharge Your Model with Liger Kernel From c495433468db70b01784f7866e1da25d95d12332 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Dec 2024 21:57:28 -0800 Subject: [PATCH 03/13] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7e0ea15a9..ef21c81d9 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training. -We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out our [deep dive thread](https://x.com/hsu_byron/status/1866577403918917655) +We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655). 
## Supercharge Your Model with Liger Kernel From 00c2b35ab32253d6a836306bf369cdca39a76fe3 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Dec 2024 22:05:23 -0800 Subject: [PATCH 04/13] Add more post training in readme (#472) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- README.md | 13 +++++++++++++ docs/images/post-training.png | Bin 0 -> 21724 bytes 2 files changed, 13 insertions(+) create mode 100644 docs/images/post-training.png diff --git a/README.md b/README.md index ef21c81d9..ab5031949 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,19 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. +## Optimize post training with Liger Kernel + +![Post Training](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/post-training.png) + +We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules. 
+ +```python +from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss +orpo_loss = LigerFusedLinearORPOLoss() +y = orpo_loss(lm_head.weight, x, target) +``` + + ## Examples | **Use Case** | **Description** | diff --git a/docs/images/post-training.png b/docs/images/post-training.png new file mode 100644 index 0000000000000000000000000000000000000000..c14612739b79178a435dc330fef860e214416557 GIT binary patch literal 21724 zcmeHv2UL^U);2TFh;uEB6{XJDupy`*HDqi=qzzrVqLfgLbO^*59Z^b9QA%h^fY6a1 zDTxRH1Oq~V(9uLf2>~Gl2qDRT!rVL8JL5O^uj{-2TC;#QDSr2i9lYhnt(e96NcS5WaDJiE6<>Q{@)D4HxgbJZd`Bv~;NF z{5ezF#;-O9!?*7|vp#Qq^9Fan-A!9h{k(3Iy3)xGD3L#_)9*U!+YVV83GQ=E<8uo* zw!R(VHN0tl>UhfdOjYxKupYM{gHZ5~7(W}k0j%y@*$d#8yAy}L6cbZE58EmxcI$gh zuzImgohDz2iCwve-V1&?=eOevF|pkl|EXW;vWCqO`BlQxKK95t(qxyuk|B{}OMFuF zf`RdCmJKYIJar1g40~Ep)rKEQQT`9CE6~q>GYNh)9C})l8!+y{Tc^0 zhMMUp7;`it3YbKP3e~ES5N3=2z0s|d=Gm{rwj4h-Y$9VZiyYUYG)O>)9PGw1%pkCu z`m>8oO--Wzcd^A`ERP=fw&OL^ zr3>$FOKiDYwPb}38~j>n?}h<=hm2|t+V1{>OnGovZS0UBA*2p3izxG z0eLUN#cgy z-!@B}TA6l*4yM(bQEJmkn1KG0I-1gQ;7IUz3X79#9F_;sIOR0J*Sh38GP)cynj_#> zIPxljg%7sTMPeolDLKoke-ylv5Ih}_B(k-Lr(7L%q$+gpp0m8eC+PPoq1i{SSLY7`c<8sJAEb_xXeH-yD1;xm*qqoDhLaQKa z?=}%E2?e9KCsAxQR2igNQ?*h~7>HU5WOa(Y+@zVK89BPjP!ng{g>%RL`7&8ta1T4C z`s&wZS4((mD&t-O%7O)BVL4>-0YiQ(ZNq|T?@#L#&;A!8FDO`YL&J)i8uwGh*S74j;1$&suTX zhMG&8)vUW5K6F@Rh9*g+G)kuBsM|BcW)niP#9qMe(!E`VnU7vrmH4-ETJ4EN2pT6+}JGquTqA7+@2 z!j&rnYPIh-vpTnvgt(9qeB8t#vOl?^Lm@^wiwny|Fijlucki6tE|%y8gnEK-gUW7{ z#8={VW@TrCn%|Zt?_@Wss$$Q(W)Q+%P3`@AEP4_ODj6zXb7K_*U4CUJ!FN<*7P z^|>r9xx=O%fn!gsDiv~%1tgPJh(i8s`l+y*42rDmZIfNRnsdI@L7tb`oX`<=THduE ziD(M@;ETlp4{=Im@t9q!=3V3P`M7gbAJwCpTT50M&wp|6;Qb_lyz;R+OuW^G-GcO+ zKNqUurQg4WJh?bCY|F3DgrLmp(w*v>`Hs;>Ei=P4w#4^BN!f?h6qP)u(9zJ@I#&M? 
z@+70p%x^^j&X_C~t`rYq6(XhmNyttOIqG5WBHNY{kvIM&CtEd^ zd+BqBpK%4WjofU*WCrq#*U0Ljc#TW!7dH{tv*Cwi`SPf@ z1wge1wcWP|zLCkA*+5F*tZ;w89#b?5^om@PNv#S&8q@@{8B-DpCM5*&=vE*?V`;4&87#ID?Za~_6x#(|+0VCULa28sN^ zdi?UEzHqeFJ*6!2zEDKn)k5xp^5HKENx8Y;safTPFT$_#7FoBI>gL9#o1}d_$tUU- zd5#xJ2|Xsb=ZgNxq|U04WjtlL$$Dyb3wdcbwKQhFr#ReSnJh1OFJB3>+N;k|A-8`p zrIv8Bpqhs}OWdP8F48wxvFna?=L9Fiyl4`#hI3j>ER<|S${D+p*p&cT;Iq{WOTzpLxCsKi4y$OT z2#()I{-bSS^6Qpup=I;S$yD!goA0wGXsNHOi9(G+|ZKL46I7KX_fWR zcbRf>sj2EVQnR5L52b{ansaQ3aW&j1Y`?nZQi((|?Vicd^!}#ihodgyF1gG;?W}VC z%Z2P%)GGbltn)KnvNV;{(~!|HumiWAO2^>7inTGiJQc8k49sm+<)S|ZdvmbX`ckMY zouSG3xXANzCoBcHmcN*W&+pht<*Qy0S1nSEM$8F`Vdrn^aO}@ubyHv2N>|8k`Jm$y zYkQHz^9g*FjG-pQ)*yC>DNDOlV|P=(BAMJ;_6SZXRhelsW0#kjz_i`x5*9pN< zlgbh9=0hz`-JB>d7N-b>=A_R8Aw~1*IR5wNtX5V6n}}7z0fTS-u*bmJn=~>^GOL|z zFO2Wi{EALEq{In~V9a{TpMkrF|KuI$hrxYQX}r;4O7rx%eBin#Z&ztD6+dDbUL*KS zer+cs<~l7W(5w6AE7PiTZpTBr{oxPk3>7SzP>F-Pk1QWFT9Q*}>Yq7wE|PwBI@XOEFx_WTp@zUd4Ov;H_u;iR ztNCu5l@v>~KhF^E2emC>x=Q^f8TG@uKxR5q`#yfS*7OKWh zoF@nFVE^D*rXI4}aT|StmZKHgWKt`ET5j+qEXds(9hOlmIwwVQv#6ULphlBSvUbZc zM+o(=ftG6sGdWR8gJkDp-S}W8B{w;+BO1l@=s)QH7!zMDd)C*Em*ZWiK~|?{!_(G7 z;I?Y=#oO5`stFMI9w+f%r zmL4;=^8NDRG^ZQhqNs^@#_#T{U z(`+@fsAizdz^iN4E3)QHFNJh^&yp4YdAC>ZF6tT5VIM?{CA;OKxy&3~&O+;3VvUBk zj(>zQA+eWQOftz2fzY~1m03mfa#<2L%~dTO!9G2ETq=;XX|DNp$YtI=Dt_#kzt$Z@ zo3~A?3VGjhT%|vIVQ%@_j}*LfY*#JsX4nxgwiH9Xi?ExFcH%9pWFwd%PZzf5?C**Pp zMs>1nbU(#A|1wV9qB8BYW&zfFdTV?&`QUxdy?sZ>9?es(Rlls8tHBK)e8#n_U0Eq) z(0AzCn^koCdyaf}*r-EV!H8tn5(<}A^I;d2aXM408)SM2TYQFYg2Q@+cUDt$#xRDj zDWrx-;q<8#QpH_EdOYNkXHCe!icK@4Y+WHPSp$QV%b+QSLYXy`_kzjRzBIjS%c zM{=BFx8s^rXEXb!pQ-{AYq-xyqYwpShh&0{NR}EpLAz7$t3!@=r}kYgHi12huSSz= z9MK_nofYL8oHA(*1!)_RI2=Wy2lL> zS;njk~#O~oXsS!HmAj-dOsHB~v&7#7?apN zUy8ZFw#t$ilNN~P4g@MDw#K-ZqH3(>ndU$xSSmL9UneUeOjy$0(V3AW1z(}nS2bY6fUUMR9F?ef318K((^o?F&!@YPW zECFnBUwkJb*35U%DT$m+GT=F24w9u|Po@wYUTkE$OigyFl}_V@?(4P^R5zU3ha2<-Xs7r~djc}bZS%h%3940$JDVlO=p zXJB{3mN2ld>E)&P?QY227-9x&{FvMo`@k(3m{c5ND;|GB-_Bz@*gDo15#xkB<6Y!q 
z&WjReE>TuZxs_FQUhVLFc2fB?0^_)wUd&^mS-@+s&1)vUYD(br9YmP#Gg6A83}FJ# zddl_Nfvs~s=D-^dr*Sq1PDRARNVi8})>E2I!*0nquO^*)$4C_qOO%`jS}c48@KJqA zW^i(9EyEv)m7|{U*3ci`a&@RQIsT+!0gSysQD>_nsg6|x8NpoKQ!+2J&ezBCttyZi+Yv`3UuDY@Ont_J^YTs*`G`ICot>|)_+ zgK4F;-`|*&0Y<)wTCa602oB~bAxkx*5S3A62f`NYOLD>_^**q=x|m`~_paK`t3H9} zuZ0Z+KHD58QdVVne?^wjc50)MFFiF#$ZF2Kqt_;i^{z39|>*zLaZ!55$$x(OqO1quqc^OO>5KBE_xOgb`^OQE{AZl zO3{>}>ZryRG_tuE3?+7(#MWZsSGwXJ-&0SJfr`@FBRI&pF+_J~8OgisTdG$()jE0( zao+WD*6IwykK&%TxNE1x-t-RB-V^6>nS8e=)freR*zV+pUG9Y7WNM2pNs?p2ux&T3 zY(Tbc5qHAi%Wkj_dDmo=Y0YT1ae)nyvG!_Ts+S0V4tf1^59-Z5mPS!bqK~~qqbPi< zbV|J@9o_naTfm*c*t{>qT#QX*N#El9nVzcDB?x4q;c}2z8w%2 zY14gZQ4f1Gcsx~<{YOt0U^m=Q_K+)aoG#lAyN|Xyc1qx;Iwhecjt!EZ3#WzEu~TO~ zW7mthJV5V#xz~WV;qfn#%wVROio!5|;fiVcqU@|kx34?!GFa>exJBK5jf3(hr@|UJ1qelTu^TtuntYXLS`)Wh?OtT^{;_Y>Hbf>LCk@+R=|ReQ z@|5kf*Sa>xf{V4b1bmH;UkJ@ReNO#>UZ5%DTbi4b-$j&g2}OiVJ^hwx|H{NfY+~Pk zFn)OJIC9+Mu-6OitYSge>P`x2!AxDSOwY$$$H3(_zo23r1euzm#a z6#}!;!}$<1D+6c@yy|s1<%=6PC?;P`{ET)A$RczUX+?@H+(z&HLDD4$+ZPUjFS|l{ z{eJBl5IG^Y0-GZ_nMH zeB~4mO^5O+P(vvD6g1AV*daA-AyXksp8HTWL20FiOy&A!b4%oTB?})LBqXwJ3#^%kjB23S!^b%v(*=|1M#f0m zb_$ALEc932atjtKlGDk+9k|>@Z71>09E@mcsSLaGCo`Ws` zIf}4a^M-b)aMf=#Y$;0u`EFegQ~gV8o(O?-?;~>6ga^Gw33ZxIi=_#51yYM5m#IAL zM7Dt;2KXmf0lDKf?|$Cc1<)4phA)ErCJ%k-_x`)-LxSa4VGw79k?bS4iPw$OshF5- zsspIrpm>8E@?*Vw-}bbRcj}-2avWtBvkmuz0MH?05vcJvjh!-XPJsyKV&$Wu`~^z> z)nqjv?>DN`FFL}!RmsDFqmbRZWU;$crR1Z7R#sn*&(&!azR%*DKVjSfwLp^&lw*FoD!gVGQ*J`(^J6%7E#QDbblV^ zbv90ilk3YU_8?Ql4i0~Cv7G!DY)b3qWy~>2S0#`_W2p^!52;~A+e@A$Ls^QT4QxU)E z8H+16arf`nFKGU?f>rV8e0-HT&Rkdg@QiwYnm4^EriZy(Qcdx`O{*WfPEHM3(w6>V zPe+ihs+Wvjv~*6R|0Aq-%Y3?!jVf{wMhHL3ceST#Q7WhoT8RhIukPXrCmfm$ZD}F3 z0B=*X4<38c^E@43O(3i;`Qu*kLNi~j+i^$MY)R=Lr4g$zb=AKou+-B?k~&AKxXyx0 zYW*hE`!MD5gj4QG2u=5}RHw3bb$WEY3IK!Ex~lz|41?`B2k(J$IM7Z$X;fPqcS=3i zX$4+Rt+HoLy8n8Lp*mvK&#)WbJ^{L3zGTa`GEmEc6Rua`?ql!G6}$|>haum`qUnS% z9>YMARzHP^XvLq<8)B7)3*MQf%|HIynCcMD<7B>%vKH(|gbgbRXiAo{Nqy^G)58L) zXR4XH$M+ePMDKH%t1;3Td%VzUMJUyb?A;N2F9@8!PSEq!D92UW6hd0Ob#zoX?AOuG 
zJ(`_t$#-bB2(8Zk)%rx+=&acsv0aMjgQSF&pa(1aDMnS7UnF$A>XR1}hXN;)UFzPG zH`OK-hc4GoX%Z$jHCBY2IELP)=vPf@twZi=5{Mm&1H~${;GTR3RdvPZ)^jp-4ErB& zsoo_wmC}DHfr_O9hs-NlGbp+ObE%|10lY!Q-)0l(K#%n0C z$9n(Gg)q&@L*65`8mR#T@JV1eiv@LbwAOlhrO#-Z96PajlZ7v#C*KM;)|PabYT3&$g!zqTv}{j@ZKbCka1r+N z>fqCN{Rd2jDm<*#YH3jp3mNs#%i1;bHD3iSbi!~HV<=>;6ofcxRb&9Ly8Q5wy3jzI z5+~B5W>8C1E+|9O)=6kTkmL0=XP=XRsMv?HDsI0A#X7s{_jjf9(8FZ4(Uf)_-gG_w z+!l@NAK3yvk5cN~7O%Yp6e+K!`tR_tat!kid9^+5t(;x(hlj8(6PT=OP}rLj&6#zf zOP0{)hxO{F_-+KYakIa74eNB)W7A8AvmRsaN9wd^X6)lHW(k!PH zROdZhT@Gp(Q{#oURoSW@y~lnNYZKX7hx^(%TjWkWie4P7`}m8JOtjWq&O$&CZo+a?s)4l6Bz2zdcuWp*~Y+`Rs(EN4lYt`69 z`~QPF?eu-oK7iv^mkk_3OrD4hesjjGgitb4BKSe*R?~8!QtTkbWct2HKTf5y8~v7= zQOC2i91(BL3w#O!3!rbI1IKSZRJ`@tZ$~I?nI3CwD_qI))w!|dZX^fIaXrn=*Vcgl z^Et~-LxQIoiFr0Cmq*0bxVXwFyXzaToGdgjUjAXfBs~034SEO4GF~OCY4$Emal5nJPze=gY(SE9G!Xb^(3mUxl1n zeo`SFd9!*dFTabwG@4#iJKvgDc3Eq-mc3S37kS9Lo29!z{V&?;PqnlsT=VjgUbSAWUiq`Q!Z$_sHC+c1 z>>dHjBf0r*q1}s?|I>%XhVi^hHrSDo?u zlQXfxp%nhBKD9<9~FhzBX1N9&O#qyM*nHxV*y$?56 z=bCYaD1Mg)Z~98k<=M-q4XITp|6!DrKea@iM0Ke|-=RZ;RX)Z5skSw-cqQ({gXH!f z>NnNUCzuyKQhbtsU4JZZ0;`SDoITW}8VQp@G``~iH{Vt1H3%f+J&frsu#N-5eEin! zMDq2hln>w%iT;@SJ-UN`@kSJAO*!@d(>T>kWJGoCw6Zg?CkBC_J)7MDg);Gkw{GAf z{d%3i4Cu#+ZR_Qe`>pag8nTGScHEUd}~S&9YqxNgMBw4V}v?Ht|8 zfpKWjBZ&UKPvvAbYEvTCo*+RZu|I1C zOy`7;lsfC%)C8~|HFJLX-9__DyCVuy%#0!aA6E*>+VhNHkd+B6WfGU31_JY9Ze z6zQwUn=TWlZ_)VhpuU5*G;0hh!nX5x6Po7hLGPNvoFmUV-H8gS(g$phxV}zkwJ6QH z`NT8VZfDqIm>Eu%H0VQw)UEUsdrQa#{YiKhW^xuw@_eLt^N|vWD8Qdfoz~lEk0~3E zY4t`-YgKCSwNdtJ%A;g|C}ZJHnv%v{pM7@bbmQX)`b3;GO>RI zgl3bkHiBU7ANKGeAl5uNUboQYC*Y1NedbKS@xH&aJ@4Do`MXjZ@F)p zK3u&{z$rF%O+hqW83JH%dyLkeZ>eXh&sSY* zMo*0^+E!*(cu-wF^TOb5l_2;rf^fR)nZLJadp{G*PJx2);KrQgPVDU|i<$oEm zF%LBz#!1@UIMajhu&MI49{@4i6%0rquY>&@P_cP)@KfWEy8nB*SOg=<)l#h#{#mQ! 
z`WZ0*xdXFwhsS5Ph5TOd`=2TN z^`XB+@E@oJn)BVM{;hd=+ICp*w&R^kxJk(HM?sEod zrk36A*LoNiy@@7@$v|->W$|6>=giT8Smwr-ssPfE%b6fVT)}n-ha?k$K z&n;f%8}`?8AbQk;$@+J{<9uSk*3&P(H*JpDrq=y}RXG@$w@s>WG1Roo1Vm}^$Xq)R zv{}85Z?))z1*;g&HpCp{uNPO`bz}= zzlb39!26NVSwj)aa_}~VK~APj_2g6L$J!_>940Vq$s%;=kCq9vOt@-}JvP%_KL>1_ zzOiX}9 z)nFu+gPP>jxcZ7DY@IKf8MRGH(IxI*K&d$8Blmu9h73Qq-Gf)nzg7C2#Q_N5A2Mdl z9g{P)|B}EJs{9AxT9;aS(;%T;BuU2<0K?*^IQ6E#g7R6BOZ_9R;${}i4g-)WgB4GZ?db&Gu zj!a57EbR$Qn15kv(-qAgp$N!pfCC^ma;dhTxyM!@3aDC&?`xnlC z6BDmw3*1kehfL10|jg% z;WG*+>9+#Fzx&Q3MfXZ4fDHrnlUDpxOaHDyr5DJS(;(3VFT)E)(d%{R7RZIHc{CcX9kJilG zKTch>MueGC3<7u1|LaP$2s^GW{=62)dzvJ_C=83?P= zLrQi^d|IUQh-QiZo6X4izBvI0Wx<{f#Q76wWuy7_K*A~hM+xEcHIkI+CGz$r;gkWFVBO$`%H0N9PjZe2#?==(gM%PoQ4GDSt z(18P}^0Lal0(8Df`U{z%_>q(K(B+vxVzP*0n)fV-aVkBj2+Y0^b$w_0pm>8KlKMpe z1sbXz7aSQ~J&GL0m$UlI6v6<@(?|)$ZOfnQ2x>aDo~}z~0~SdpsMKa9)P6Y7gHK<+ zpT0ZVcDnimPuKZ9sG43%=TEy*M3g~yyH)=M+tHQYE2}xdA0jHn<&nw&_RAZn>ij-h zsqe(}tWYqUkG((l%b_NIYw>F96}kSst1Yps7b`{uB;vB9h3yV6(`kxme1QQ#3-X)l zrk#~MIx4TvHzNo3Si4ycvLn`{P3GU-67`Ws6!I+AAwXbb5uBF`hS5laHAD-eyhe2c5r1 zPG6JHZ)#594`c#44*-UM18ANBjiW4+CvI5Up=<1o@t$B9h2L};e#S?NHPl@!93!qK z+)l^RF!ccAMY#15SFD8Y?MQ5k{Ik@DSuPb`-LA*QtL8$bcYPNOl?%8(sTGNs0}JDziQSh04_SU6K^X^=02a)k9J3z$vRHv)0MKY`>DwJV(BV&( zamNCE9+=|G;V?$}l^YPN`=%Oe#W;dSZIGrV8T%N(;BI z`eRS!(X^NF8>u1=Rg>w#j_o_;vw>5PG9dQ(Ut8?qAyYz#O@=b=tlDz@@C_wO*V3B^ zsg_#$tVe}{ds})%LtIWkCiRS^T7O7s4d(P%Nf_G3dJskQ|9mQ9_zGlN)TwZnG<=cH zfpQgu*V~cwOM;xTx8;uuDq~jOjJ-P@?$hvPH!4ZSKP|M7pzGAVm1LuG(kGmqoPQPZ z1;4k@R-EomzFSb__Lx@eq!d0E6``}C!HF8vPi;50mDwCQH@Xxysf zEJ|M$#eMuz<-O$BLpK^W_dLX`HCp&Xu1t5RyPM^gb%jHw(A8w9U39XsU^gx`DVuct zlVg3SfSOmxZ+d@Qe)URd_e=TJdU%b2gxXPlk z9}>g+i74UECr=t$4(f(6Zr+e4R$cvKgg+B#6?0D8_ptZ(qq}hJ1rBx97#B=US`%E= zICOxLhbRu`%v?p) zd_%g(ddG|mKK}s`b1vnG)aLUkMc)PzB6SZQA^h0WLD%4Zh>JjfYvk@n8)(D}occ?h zC$|I3aAi3&Sl9YjGx+_YM~V~Y6zJsX!K;7mLKr_}v=_?Xq4q^^t67>UYf!TG-Pe=x zw%)nu@0vK;qGmAhswX}=b~cq8#?w*Z6{8}uq==$8g~mR_;m@>yT9;ajd0}euA`>1_ z;n)Eav|2*N*`k3xoaH!bEkm`3$*}*Bfay^l62`ITq# 
zmmImRR(BbDfHCC1Qd&K$s)Ih>+YyZMNi*41=KY4S816~k3X?gpaDkxFakj-6N+O@1 zIwh@Z!Fx7Ydf}cHw8*Byy`50NoWRB!iTsi{rOFrdq_CcYRkb+2n9`zG9!>kg*4ERq zxhF?t+%|fcT#@TcssS{0kps6;RYpXfTV1sU)0=u6QwOG+oKw^}S%mFYLr;#rQR`4& zbJo0bN2UNP-Wa-k8=82$_T#ICwV^aotIxly37j>OUw$c`*SxKf45eLklKJJnnHilS z0f0!BXpBsPV|cLt!ofxO&{Y*V=*LK_+Pxy0`Z-GFZwr0xm_&Qb1q^x1^aw=rV9G4W%F|64l#k38I0@2S8o&VtzvHow5>M~eJY}R%$3hP)x8n1wHQ-TGWBvE}hoX28W zZ@~jH`~Eu6(BvZ_ms?lF?)j;_&-r)&VO(Os*`gQw~KQ5mWBWF=`AX@`e|l6eADt`(pW*RcA&8 zgD5Q(!4J|ETQ-(-bM;+$Vs9B{U);KX(0?zjoU(gq9rWWmmWg9|rVYuyjY z#$Ssa#B|>e0?sEN5g`{ck@8*?ETn+W7F?#Ne?rv#H2FyXxZWtYJ@9V!MlknBt4kD1 z62SN`H4&|NtqVgGB}5maP-UQxmOC|PTjgyiN*s<_7F<*c>5N)WF9zLk3)`V{+b4nY z7mYCEKy%(go>kou$nv|}FSu#@q(OYMJ&0*OgpZ<7 zOSq6OFw+!uDN=4Z3g#2624B>@(c5*B_0v$*##GQE*Ws7`H?f7@Oru6nGs;7V`sc(O z0A(dPjW;yDh2hc5PFt#-D(@Lc2E(W{DBT>;F<*kohG@G#VL3D=x25fJ(Gi61u- ze~hVFoo^do6;wNbtB~-%tRfV`O#p*0kaGwFt39=JW6*LI)t_-Lj|M0W$5 z2x4S4eDJu|X#eF)d74{K00&|$zbT5c=~`dSD1WBuv2PF2qlA%6l#@?Qhfj~x)nIhc z+NUc}#{89H!Bp}31_IzjSv6pAOmGR*6~PX3lMQ`BD9MatzK`HIa)}R4bd=Op(NWqA z-*>arhMg$7A3CY;Xk!nfn#=!G0|K~+yUiTb3GE_wzWt5*uWvi+;>Lj+IWH=+<)JTA z%4L%fYioebG`$d&v}#5;MN^Nv0qH2(&L#9RU$bam0ckRc{;PHc#XHw}`!GdwyJ$ z)Ms7L49H`i3sh`6)aSpGS~hCU(&R{-S#bBYC2RBL8Q z(RuW^pQf0cEE{!1?5zhSV6QtO)y80yfG#9#ZS@6hm(SjNi6*w?Pi}#(egv(VMdt>Q zY-1@$$8(a@4ZYrAJP~+B^V#x*wZ2MGqhl;Appm)A2}4ABh7Raur=D&m$T|;9inX}@ zw!6acp)>6ELI<3P z&+oia9rQ8`qAiDe{v*3CLXS@GlxO*SSo%G7|KyGI^=cdwf=LhIC{vxq9k6Kohvp6$ ze=p{iLo8Q1FS9d*?BQVM|1n<4G0n2j=7|GufZG1zn@jkWseB3`7g`RRumRaRCtUC` z0;;V-hCx?4{X{@@TL^7?y1I(D)UWT)>2GsbES)-JZ->ZVXM3pggcE>Wd*6pYW?1kgjEIdy@Q3Zo;oA;?#>X90fljnWLz5 Date: Wed, 11 Dec 2024 22:09:36 -0800 Subject: [PATCH 05/13] align post training loss at the center (#473) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 
ab5031949..0114eaaf1 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,9 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and ## Optimize post training with Liger Kernel -![Post Training](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/post-training.png) +

+ Post Training +

We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules. From 0bb6c72d03f0fc570b9e954c42b723ab7e0c315f Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Dec 2024 23:11:13 -0800 Subject: [PATCH 06/13] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0114eaaf1..32758a071 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. -## Optimize post training with Liger Kernel +## Optimize Post Training with Liger Kernel

Post Training From 21bacccd107b40146e599e2c27af46d7d157f174 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Dec 2024 22:43:02 +0100 Subject: [PATCH 07/13] [Transformer] fix ORPO loss for MOE models (#479) ## Summary Add missing MOE loss when specified in the trainer. - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --- src/liger_kernel/transformers/trainer/orpo_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index 3605b9f1b..184430ac1 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels): outputs.last_hidden_state, concatenated_batch["concatenated_labels"], ) + # if aux_loss_enabled, add the aux_loss to the orpo_loss + if self.aux_loss_enabled: + orpo_loss += self.aux_loss_coef * outputs.aux_loss + return orpo_loss, aux_outputs def get_batch_loss_metrics( From ac5667471e24434c378781c5400b19d595d05fd8 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Mon, 16 Dec 2024 22:01:19 -0800 Subject: [PATCH 08/13] fix: correct typos in docstrings (#482) - Fix 'transfomers' to 'transformers' in mixtral.py - Fix 'Emebedding' to 'Embedding' in orpo_trainer.py ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: byhsu@linkedin.com --- src/liger_kernel/transformers/model/mixtral.py | 2 +- src/liger_kernel/transformers/trainer/orpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py index 22fea53da..145bc78cd 100644 --- a/src/liger_kernel/transformers/model/mixtral.py +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -38,7 +38,7 @@ def lce_forward_deprecated( cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" - Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy + Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy Args: diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index 184430ac1..04391fa5f 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -17,7 +17,7 @@ class _FSDPForwardRedirection: This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`) - will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of + will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just the `lm_head` part of a model, we need this trick too to properly get its params all-gathered. 
From 61eefe9a4429459351979dc7fe1de746fd7ca86f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 18 Dec 2024 23:19:38 +0100 Subject: [PATCH 09/13] fix chosen_nll_loss in chunked losses (#486) ## Summary Fix the nll loss in the the chunked loses when the model is a decoder only model, by shifting the logits and targets - Hardware Type: - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/cpo_loss.py | 11 ++- src/liger_kernel/chunked_loss/dpo_loss.py | 8 +- .../chunked_loss/fused_linear_preference.py | 92 ++++++++++++------- src/liger_kernel/chunked_loss/orpo_loss.py | 10 +- test/utils.py | 17 +++- 5 files changed, 101 insertions(+), 37 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 2b8052e25..1d771753e 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -47,6 +47,7 @@ def forward( alpha=1.0, compute_nll_loss=True, compiled=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx, @@ -60,12 +61,13 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -80,11 +82,16 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, + is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + alpha (float): Weight for the NLL loss. + compute_nll_loss (bool): Whether to compute NLL loss. + compiled (bool): Whether to compile the loss function. 
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index @@ -92,6 +99,7 @@ def __init__( self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCPOFunction.apply( @@ -104,4 +112,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.alpha, self.compute_nll_loss, self.compiled, + self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 5f1b17cf5..082036eb5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -67,6 +67,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -83,12 +84,13 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -103,6 +105,7 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, + is_encoder_decoder: bool = False, ): """ Args: @@ -111,6 +114,7 @@ def __init__( compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. 
""" super().__init__() self.ignore_index = ignore_index @@ -118,6 +122,7 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model + self.is_encoder_decoder = is_encoder_decoder def forward( self, @@ -142,4 +147,5 @@ def forward( self.compute_nll_loss, self.compiled, self.use_ref_model, + self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..1ede7aca8 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -26,6 +26,7 @@ def forward( ignore_index=-100, alpha=1.0, beta=0.1, + is_encoder_decoder=False, compute_nll_loss=True, compiled=True, use_ref_model=False, @@ -56,6 +57,7 @@ def forward( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. 
@@ -94,6 +96,7 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, + is_encoder_decoder=is_encoder_decoder, **loss_kwargs, ) @@ -282,33 +285,48 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + is_encoder_decoder=False, ): - len_chosen_chunk = target_chunk.shape[0] // 2 + # Calculate logits and log probabilities logits_chunk = input_chunk @ weight.t() if bias is not None: - logits_chunk = logits_chunk + bias + logits_chunk += bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + # Split chunk into chosen and rejected portions + len_chosen_chunk = target_chunk.shape[0] // 2 + + # Handle sequence shifting for non-encoder-decoder models + if not is_encoder_decoder: + logits_chunk = logits_chunk[:, :-1] + log_probs_chunk = log_probs_chunk[:, :-1] + target_chunk = target_chunk[:, 1:] + + # Calculate NLL loss for chosen sequences chosen_nll_loss = 0.0 if compute_nll_loss: + chosen_probs = log_probs_chunk[:len_chosen_chunk] + chosen_targets = target_chunk[:len_chosen_chunk] chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), + chosen_probs.reshape(-1, chosen_probs.shape[-1]), + chosen_targets.reshape(-1), reduction="sum", ignore_index=ignore_index, ) + # Calculate per-token log probabilities loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 ) average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - + # Split results for chosen and rejected + chosen_logps, rejected_logps = ( + average_log_prob[:len_chosen_chunk], + average_log_prob[len_chosen_chunk:], + ) chosen_logits = logits_chunk[:len_chosen_chunk] rejected_logits = 
logits_chunk[len_chosen_chunk:] @@ -331,6 +349,7 @@ def _compute_loss( ignore_index=-100, alpha=1.0, beta=0.1, + is_encoder_decoder=False, compute_nll_loss=True, use_ref_model=False, ref_input_chunk=None, @@ -350,6 +369,7 @@ def _compute_loss( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). @@ -369,33 +389,43 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + is_encoder_decoder=is_encoder_decoder, ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) + if not is_encoder_decoder: + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] + ) + else: + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) if use_ref_model: with 
torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - ref_input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model + (ref_chosen_logps, ref_rejected_logps, _, _, _) = ( + LigerFusedLinearPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model + is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family + ) ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..7dae8057e 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -57,6 +57,7 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -69,12 +70,13 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): @@ -88,17 +90,22 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, + is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. + compiled (bool): Whether to compile the loss function. 
+ is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearORPOFunction.apply( @@ -110,4 +117,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, + self.is_encoder_decoder, ) diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..fc114d163 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,11 +350,13 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + is_encoder_decoder: bool = False, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.is_encoder_decoder = is_encoder_decoder @abstractmethod def alignment_loss(self): @@ -372,7 +374,6 @@ def get_batch_logps( logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - is_encoder_decoder: Whether the model is an encoder-decoder model. Returns: A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. """ @@ -381,6 +382,9 @@ def get_batch_logps( "Logits (batch and sequence length dim) and labels must have the same shape." 
) + if not self.is_encoder_decoder: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() loss_mask = labels != self.ignore_index # dummy token; we'll ignore the losses on these tokens later @@ -440,6 +444,9 @@ def concatenated_forward( def cross_entropy_loss(logits, labels): # Flatten the tokens loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + if not self.is_encoder_decoder: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism @@ -461,8 +468,12 @@ def cross_entropy_loss(logits, labels): chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1] + rejected_logits = all_logits[len_chosen:, :-1] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, From 7a781b7adf00f515d0d77552c7324fb7261baf51 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Thu, 19 Dec 2024 13:18:23 -0800 Subject: [PATCH 10/13] Revert "fix chosen_nll_loss in chunked losses (#486)" (#489) This reverts commit 61eefe9a4429459351979dc7fe1de746fd7ca86f. 
## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/cpo_loss.py | 11 +-- src/liger_kernel/chunked_loss/dpo_loss.py | 8 +- .../chunked_loss/fused_linear_preference.py | 92 +++++++------------ src/liger_kernel/chunked_loss/orpo_loss.py | 10 +- test/utils.py | 17 +--- 5 files changed, 37 insertions(+), 101 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 1d771753e..2b8052e25 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -47,7 +47,6 @@ def forward( alpha=1.0, compute_nll_loss=True, compiled=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx, @@ -61,13 +60,12 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -82,16 +80,11 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, - is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. - alpha (float): Weight for the NLL loss. - compute_nll_loss (bool): Whether to compute NLL loss. - compiled (bool): Whether to compile the loss function. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. 
""" super().__init__() self.ignore_index = ignore_index @@ -99,7 +92,6 @@ def __init__( self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled - self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCPOFunction.apply( @@ -112,5 +104,4 @@ def forward(self, lin_weight, _input, target, bias=None): self.alpha, self.compute_nll_loss, self.compiled, - self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 082036eb5..5f1b17cf5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -67,7 +67,6 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -84,13 +83,12 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -105,7 +103,6 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, - is_encoder_decoder: bool = False, ): """ Args: @@ -114,7 +111,6 @@ def __init__( compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. 
""" super().__init__() self.ignore_index = ignore_index @@ -122,7 +118,6 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model - self.is_encoder_decoder = is_encoder_decoder def forward( self, @@ -147,5 +142,4 @@ def forward( self.compute_nll_loss, self.compiled, self.use_ref_model, - self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 1ede7aca8..fff0791ec 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -26,7 +26,6 @@ def forward( ignore_index=-100, alpha=1.0, beta=0.1, - is_encoder_decoder=False, compute_nll_loss=True, compiled=True, use_ref_model=False, @@ -57,7 +56,6 @@ def forward( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. 
@@ -96,7 +94,6 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, - is_encoder_decoder=is_encoder_decoder, **loss_kwargs, ) @@ -285,48 +282,33 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, - is_encoder_decoder=False, ): - # Calculate logits and log probabilities + len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: - logits_chunk += bias + logits_chunk = logits_chunk + bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - # Split chunk into chosen and rejected portions - len_chosen_chunk = target_chunk.shape[0] // 2 - - # Handle sequence shifting for non-encoder-decoder models - if not is_encoder_decoder: - logits_chunk = logits_chunk[:, :-1] - log_probs_chunk = log_probs_chunk[:, :-1] - target_chunk = target_chunk[:, 1:] - - # Calculate NLL loss for chosen sequences chosen_nll_loss = 0.0 if compute_nll_loss: - chosen_probs = log_probs_chunk[:len_chosen_chunk] - chosen_targets = target_chunk[:len_chosen_chunk] chosen_nll_loss = F.nll_loss( - chosen_probs.reshape(-1, chosen_probs.shape[-1]), - chosen_targets.reshape(-1), + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), reduction="sum", ignore_index=ignore_index, ) - # Calculate per-token log probabilities loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 ) average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - # Split results for chosen and rejected - chosen_logps, rejected_logps = ( - average_log_prob[:len_chosen_chunk], - average_log_prob[len_chosen_chunk:], - ) + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + chosen_logits = logits_chunk[:len_chosen_chunk] rejected_logits = 
logits_chunk[len_chosen_chunk:] @@ -349,7 +331,6 @@ def _compute_loss( ignore_index=-100, alpha=1.0, beta=0.1, - is_encoder_decoder=False, compute_nll_loss=True, use_ref_model=False, ref_input_chunk=None, @@ -369,7 +350,6 @@ def _compute_loss( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). @@ -389,43 +369,33 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, - is_encoder_decoder=is_encoder_decoder, ) - if not is_encoder_decoder: - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] - ) - else: - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) if use_ref_model: with 
torch.no_grad(): - (ref_chosen_logps, ref_rejected_logps, _, _, _) = ( - LigerFusedLinearPreferenceBase.chunk_forward( - ref_input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model - is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family - ) + ( + ref_chosen_logps, + ref_rejected_logps, + ref_chosen_logits, + ref_rejected_logits, + ref_chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 7dae8057e..c860d4bd9 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -57,7 +57,6 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -70,13 +69,12 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): @@ -90,22 +88,17 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, - is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. - compute_nll_loss (bool): Whether to compute NLL loss. - compiled (bool): Whether to compile the loss function. 
- is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled - self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearORPOFunction.apply( @@ -117,5 +110,4 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, - self.is_encoder_decoder, ) diff --git a/test/utils.py b/test/utils.py index fc114d163..3d3799ad0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,13 +350,11 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, - is_encoder_decoder: bool = False, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model - self.is_encoder_decoder = is_encoder_decoder @abstractmethod def alignment_loss(self): @@ -374,6 +372,7 @@ def get_batch_logps( logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. Returns: A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. """ @@ -382,9 +381,6 @@ def get_batch_logps( "Logits (batch and sequence length dim) and labels must have the same shape." 
) - if not self.is_encoder_decoder: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() loss_mask = labels != self.ignore_index # dummy token; we'll ignore the losses on these tokens later @@ -444,9 +440,6 @@ def concatenated_forward( def cross_entropy_loss(logits, labels): # Flatten the tokens loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) - if not self.is_encoder_decoder: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism @@ -468,12 +461,8 @@ def cross_entropy_loss(logits, labels): chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1] - rejected_logits = all_logits[len_chosen:, :-1] - else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, From 3205342a6a7209c55ca3a4bd97e986961fdc792e Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Thu, 19 Dec 2024 16:49:14 -0800 Subject: [PATCH 11/13] fix dpo tests: reduce tolerance and change default compute_nll_loss false (#490) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/dpo_loss.py | 4 +- test/chunked_loss/test_dpo_loss.py | 69 ++++++++++++++++++++--- test/utils.py | 10 +++- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 5f1b17cf5..cf07e186e 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -64,7 +64,7 @@ def forward( ref_bias=None, ignore_index=-100, beta=0.1, - 
compute_nll_loss=True, + compute_nll_loss=False, compiled=True, use_ref_model=True, ): @@ -100,7 +100,7 @@ def __init__( self, ignore_index: int = -100, beta: float = 0.1, - compute_nll_loss: bool = True, + compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = False, ): diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 0ac8faeb8..b73a69a57 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -23,10 +23,17 @@ class HFDPOLoss(HFAlignmentLoss): """ def __init__( - self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, ): super().__init__( - beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, ) def alignment_loss( @@ -61,6 +68,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -72,7 +80,10 @@ def __init__( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) self.dpo_loss = HFDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ).get_batch_loss_metrics def forward(self, x, ref_x, y): @@ -95,6 +106,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -106,7 +118,10 @@ def __init__( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) self.dpo_loss = LigerFusedLinearDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ) def forward(self, 
x, ref_x, y): @@ -132,14 +147,27 @@ def forward(self, x, ref_x, y): "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-1), - (1.0, torch.float32, 2e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), ], ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("compute_nll_loss", [True, False]) @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + compute_nll_loss, + ignore_index, + beta, ): B = 2 * B # dpo loss requires B to be even @@ -149,6 +177,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -158,6 +187,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -251,7 +281,10 @@ def test_correctness( ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) -def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): +@pytest.mark.parametrize("compute_nll_loss", [True, False]) +def test_correctness_functional( + B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss +): B = 2 * B _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar @@ -290,10 +323,28 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply( - input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1 + input1, + weight1, + target, + bias1, + ref_input, + ref_weight1, + ref_bias1, + -100, + 0.1, + compute_nll_loss, ) loss2, aggregated_aux_outputs2 = 
liger_fused_linear_dpo( - input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2 + input2, + weight2, + target, + bias2, + ref_input, + ref_weight2, + ref_bias2, + -100, + 0.1, + compute_nll_loss, ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..48fcf3601 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,11 +350,13 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + compute_nll_loss: bool = True, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.compute_nll_loss = compute_nll_loss @abstractmethod def alignment_loss(self): @@ -448,9 +450,11 @@ def cross_entropy_loss(logits, labels): return loss labels = target - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], labels[:len_chosen] - ) + chosen_nll_loss = torch.tensor(0.0, device=all_logits.device) + if self.compute_nll_loss: + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], labels[:len_chosen] + ) all_logps = self.get_batch_logps( all_logits, From 79e2b02a4a4ffafe111c3f8ede7df4fb56db890e Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:08:18 -0800 Subject: [PATCH 12/13] CPO & SimPO add label_smoothing (#493) ## Summary Add label_smoothing support for CPO and SimPO so that they align with the huggingface [interface](https://github.com/huggingface/trl/blob/b668048fe1931c57796ad5ae3f10852337ce7565/trl/trainer/cpo_trainer.py#L645C1-L658C14). ## Testing Done - [x] Something wrong with the unit test. 
I'll have to fix it - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Mecoli1219 --- src/liger_kernel/chunked_loss/cpo_loss.py | 18 ++++++++++++--- src/liger_kernel/chunked_loss/simpo_loss.py | 21 ++++++++++++++--- test/chunked_loss/test_cpo_loss.py | 25 +++++++++++++++++++-- test/chunked_loss/test_simpo_loss.py | 24 ++++++++++++++++++-- 4 files changed, 78 insertions(+), 10 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 2b8052e25..987f0cdcf 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -9,7 +9,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): @staticmethod - def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): + def preference_loss_fn( + chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0 + ): """ Paper: https://arxiv.org/pdf/2401.08417 @@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). full_target (torch.Tensor): Non chunked full target tensor beta (float): Weight for the CPO loss + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. 
""" logits = beta * (chosen_logps - rejected_logps) - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = ( + F.logsigmoid(logits) * (1 - label_smoothing) + + F.logsigmoid(-logits) * label_smoothing + ).sum() / (full_target.shape[0] // 2) + return loss @staticmethod @@ -45,6 +52,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=True, compiled=True, ): @@ -58,6 +66,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -65,7 +74,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -78,6 +87,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, ): @@ -90,6 +100,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled @@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, ) diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 7efa0603d..2dc9f1a6b 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): @staticmethod def preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5 + chosen_logps, + rejected_logps, + full_target, + beta=0.1, + gamma=0.5, + label_smoothing=0.0, ): 
""" Paper: https://arxiv.org/pdf/2405.14734 @@ -33,9 +38,14 @@ def preference_loss_fn( full_target: Non chunked full target tensor beta (float): beta weight gamma (float): gemma margin term + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. """ logits = beta * (chosen_logps - rejected_logps) - gamma - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = ( + F.logsigmoid(logits) * (1 - label_smoothing) + + F.logsigmoid(-logits) * label_smoothing + ).sum() / (full_target.shape[0] // 2) + return loss @staticmethod @@ -48,6 +58,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=False, compiled=True, gamma=0.5, @@ -63,6 +74,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compiled=compiled, gamma=gamma, ) @@ -70,7 +82,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None class LigerFusedLinearSimPOLoss(torch.nn.Module): @@ -83,6 +95,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, gamma: float = 0.5, @@ -96,6 +109,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.gamma = gamma @@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, self.gamma, diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index f0fef7734..a0c4050e5 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ 
b/test/chunked_loss/test_cpo_loss.py @@ -86,6 +86,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, loss_type: str = "sigmoid", simpo_gamma: float = 0.5, ): @@ -97,6 +98,7 @@ def __init__( ignore_index=ignore_index, beta=beta, loss_type=loss_type, + label_smoothing=label_smoothing, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics @@ -114,13 +116,17 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, ): super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.cpo_loss = LigerFusedLinearCPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -145,8 +151,21 @@ def forward(self, x, y): @pytest.mark.parametrize( "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] ) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + alpha, + label_smoothing, ): B = 2 * B # cpo loss requires B to be even @@ -157,6 +176,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) liger_lm_head_cpo = LigerLMHeadCPO( H=H, @@ -165,6 +185,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 3d0937c27..eede598fe 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -25,6 +25,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + 
label_smoothing: float = 0.0, gamma: float = 0.5, ): super().__init__() @@ -32,7 +33,11 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.simpo_loss = LigerFusedLinearSimPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + gamma=gamma, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -57,8 +62,21 @@ def forward(self, x, y): @pytest.mark.parametrize( "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] ) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + gamma, + label_smoothing, ): B = 2 * B # SimPO loss requires B to be even @@ -70,6 +88,7 @@ def test_correctness( ignore_index=ignore_index, beta=beta, loss_type="simpo", + label_smoothing=label_smoothing, simpo_gamma=gamma, ) liger_lm_head_simpo = LigerLMHeadSimPO( @@ -79,6 +98,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, gamma=gamma, ) From 15a2f58f06b1972d9d23ad898608398df7a421b0 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Sat, 21 Dec 2024 07:17:38 +0800 Subject: [PATCH 13/13] Fix Preference Loss and Refactor for Readability (#484) ## Summary Thanks to @winglian and @shivam15s noticed and fixed this https://github.com/linkedin/Liger-Kernel/pull/481. This PR suggests negating the preference loss terms to align with the formulas in the docstrings, while maintaining the base preference structure as `nll_loss + preference_loss`. This would make our loss computations more consistent since both terms would represent losses to be minimized. 
[UPDATE: This now appears to be addressed [here](https://github.com/linkedin/Liger-Kernel/commit/3205342a6a7209c55ca3a4bd97e986961fdc792e#diff-3048cb37b97e27515852c200994f3257b8ae33a465421d05184713377c0895b1R150)] This PR also tightens the tolerance in case a similar issue is encountered. ## Testing Done - Hardware Type: - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu Co-authored-by: Wing Lian Co-authored-by: Shivam Sahni --- src/liger_kernel/chunked_loss/cpo_loss.py | 4 ++-- src/liger_kernel/chunked_loss/fused_linear_preference.py | 2 +- src/liger_kernel/chunked_loss/orpo_loss.py | 2 +- src/liger_kernel/chunked_loss/simpo_loss.py | 4 ++-- test/chunked_loss/test_cpo_loss.py | 8 ++++---- test/chunked_loss/test_orpo_loss.py | 2 +- test/utils.py | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 987f0cdcf..dd84a4dbf 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -36,8 +36,8 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) loss = ( - F.logsigmoid(logits) * (1 - label_smoothing) - + F.logsigmoid(-logits) * label_smoothing + - F.logsigmoid(logits) * (1 - label_smoothing) + - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) return loss diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..4eb939a79 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -408,7 +408,7 @@ def _compute_loss( else: preference_loss, aux_outputs = preference_loss_outputs, [] - loss = alpha * chosen_nll_loss - preference_loss + loss = alpha * chosen_nll_loss +
preference_loss return_vars = ( chosen_logps, rejected_logps, diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..d615212c5 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): - torch.log1p(-torch.exp(rejected_logps)) ) ratio = F.logsigmoid(log_odds) - loss = beta * ratio.sum() / (full_target.shape[0] // 2) + loss = -beta * ratio.sum() / (full_target.shape[0] // 2) chosen_rewards = beta * chosen_logps rejected_rewards = beta * rejected_logps diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 2dc9f1a6b..5d5867252 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -42,8 +42,8 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) - gamma loss = ( - F.logsigmoid(logits) * (1 - label_smoothing) - + F.logsigmoid(-logits) * label_smoothing + - F.logsigmoid(logits) * (1 - label_smoothing) + - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) return loss diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index a0c4050e5..4090db795 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -60,14 +60,14 @@ def alignment_loss( if self.loss_type == "sigmoid": # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. 
losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) elif self.loss_type == "simpo": logits = logits - (self.simpo_gamma / self.beta) losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) else: raise ValueError( diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 9f5d81b18..112d4f05c 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -57,7 +57,7 @@ def alignment_loss( - torch.log1p(-torch.exp(policy_rejected_logps)) ) ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio + losses = -self.beta * ratio chosen_rewards = self.beta * policy_chosen_logps rejected_rewards = self.beta * policy_rejected_logps diff --git a/test/utils.py b/test/utils.py index 48fcf3601..3d08c4ae3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -515,7 +515,7 @@ def get_batch_loss_metrics( else: losses, aggregated_aux_outputs = alignment_loss_outputs, [] # full loss - loss = policy_nll_loss * self.alpha - losses.mean() + loss = policy_nll_loss * self.alpha + losses.mean() return_vars = ( policy_chosen_logps, policy_rejected_logps,