Skip to content

Commit

Permalink
Fix residual summation for the Solar architecture (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnavgarg1 authored Dec 23, 2024
1 parent 69bb989 commit 52710ea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,12 @@ def forward(
# Note, we use index 1 instead of index 0 since index 0 is used when training is enabled
bskcn_tv = self.config.bskcn_tv[1]
for i, layer in enumerate(self.layers):
# Add the residual to the hidden states explicitly. We have to do this because the cross-layer
# residuals assume the output hidden states already have the residual added to them, but the
# LoRAX implementation only adds the residual to the hidden states in the next layer's input_layernorm.
if residual is not None:
hidden_states = hidden_states + residual

if i in self.config.bskcn_1:
bskcn_1 = hidden_states
if i in self.config.bskcn_2:
Expand All @@ -589,9 +595,11 @@ def forward(
if i in self.config.bskcn_4:
hidden_states = (bskcn_2 * bskcn_tv).to(hidden_states.device) + hidden_states * (1 - bskcn_tv)

# Note, we explicitly set residual to None here to skip adding it to the hidden states
# in the input_layernorm layer because we do this explicitly above.
hidden_states, residual = layer(
hidden_states,
residual,
None, # residual
cos,
sin,
cu_seqlen_prefill,
Expand Down
1 change: 1 addition & 0 deletions server/punica_kernels/punica_kernels/bgmv/bgmv_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X,
f(T, narrow, 14336) \
f(T, narrow, 15360) \
f(T, narrow, 16384) \
f(T, narrow, 17920) \
f(T, narrow, 18944) \
f(T, narrow, 20480) \
f(T, narrow, 22016) \
Expand Down

0 comments on commit 52710ea

Please sign in to comment.