Starcoder2 : KVCache and flash attention (FusedSDPA) enablement #1149
Conversation
super().__init__()

def forward(self, x, y):
    return torch.matmul(x, y)
keep an eye on: #1148
If this merges first, you can refactor like this
@abhatkal, you might want to check/update test_text_generation_example.py for Starcoder2 to showcase the improvement from this PR.
@ssarkar2 @abhilash1910 I observed that the lines below consistently result in bad outputs.
Commenting them out gives better results without much difference in throughput; however, the TTFT increases for the first warmup step. Is there a better way around this?

Example run:

python run_generation.py --model_name_or_path bigcode/starcoder2-3b --batch_size 1 --use_hpu_graphs --do_sample --prompt "def print_hello_world():" --use_kv_cache --reuse_cache --use_flash_attention --flash_attention_recompute --bf16

Results after commenting out the above lines:

Results with the above lines intact:
@abhatkal, please resolve merge conflicts.
This was discussed internally, but from what I gather, htcore.mark_step() cannot be avoided: it is roughly equivalent to optimizer.step() in stock torch. I guess we can do away with checking for distributed initialization to save some cycles; without the mark_step(), the logits are expected to be incorrect. Considering the graph use case, lazy mode is currently the default, and since this method is only used in inference, I think the entire condition can be removed so that mark_step() always runs rather than going through a branch check. Open to suggestions.
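The simplification being suggested can be sketched as follows. This is a hypothetical illustration, not the PR's actual code: `finalize_decode_step` is an invented name, and since `htcore` only exists on Gaudi machines, a no-op stub is substituted here so the sketch runs anywhere.

```python
# Hypothetical sketch: drop the distributed-init branch and always emit
# the graph break (mark_step) in the inference path.
try:
    import habana_frameworks.torch.core as htcore  # only present on Gaudi hosts
except ImportError:
    class htcore:
        """No-op stand-in so this sketch runs on non-Gaudi machines."""
        @staticmethod
        def mark_step():
            pass

def finalize_decode_step():
    # Before: mark_step() sat behind a condition (e.g. a distributed-init
    # check). The suggestion above is to call it unconditionally, exactly
    # once per decode step, and let the compiler handle the rest.
    htcore.mark_step()

finalize_decode_step()
```

Calling it exactly once per step also avoids the duplicated-mark_step problem discussed later in this thread.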
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-authored-by: Colabrese <[email protected]> Co-authored-by: Abhilash Majumder <[email protected]> Co-authored-by: Sayantan Sarkar <[email protected]> Co-authored-by: regisss <[email protected]>
There seems to be an issue with this PR. Running the CI test with
before the PR the output is
and after it is
Note that this test doesn't use flash attention.
Weirdly, just commenting out these lines gives back the right output:
@ssarkar2 @abhilash1910 It seems like htcore.mark_step() is being called twice.
@abhilash1910 @ssarkar2 It seems that calling htcore.mark_step() multiple times is indeed an issue. It was removed for Mixtral in this latest commit: d427f1f. Can I go ahead and do the same for Starcoder2 as well? Please refer to my previous comment for more details.
Yes @abhatkal, this should be a new PR to address the issue. Best to remove the additional htcore.mark_step().
Fixed in #1405
What does this PR do?
Adds a KV cache implementation to the StarCoder2 model.
Adds Gaudi flash attention (FusedSDPA) to the StarCoder2 model.
Implementation borrowed from the Qwen2 PRs:
#1087
#1033
Validation on optimum-habana text-generation inference:
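For reference, the KV-cache idea this PR enables can be sketched in framework-agnostic terms: past key/value states are stored per layer so each decode step only computes attention inputs for the newest token, instead of reprocessing the whole prefix. This is a minimal illustrative sketch with invented names (plain Python lists stand in for tensors); the actual implementation uses Habana-specific cached tensors and the `--reuse_cache` machinery.

```python
# Minimal, hypothetical sketch of a per-layer KV cache.
class KVCache:
    """Accumulates key/value states across decode steps so attention
    over the full history never recomputes past projections."""

    def __init__(self):
        self.keys = []    # one entry per generated position
        self.values = []

    def update(self, new_key, new_value):
        # Append only the current step's key/value, then return the
        # full history for attention over all positions so far.
        self.keys.append(new_key)
        self.values.append(new_value)
        return self.keys, self.values

cache = KVCache()
cache.update([0.1, 0.2], [0.3, 0.4])        # step 1: prefill position
keys, values = cache.update([0.5, 0.6], [0.7, 0.8])  # step 2: decode position
print(len(keys))  # 2: both steps' keys are retained
```

The payoff is that decode-step cost grows with one token's work plus a cache append, rather than with the full sequence length.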