
Starcoder2: KVCache and flash attention (FusedSDPA) enablement #1149

Merged
21 commits merged into huggingface:main on Aug 6, 2024

Conversation

abhatkal
Contributor

What does this PR do?

Adds a KVCache implementation to the Starcoder2 model (a generic sketch of the idea follows below).
Adds Gaudi Flash Attention (FusedSDPA) to the Starcoder2 model (see the attention sketch after the PR references below).
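For readers new to the reuse-cache idea, here is a minimal, generic sketch of a pre-allocated KV cache; the class and parameter names are hypothetical and this is not the KVCache implementation the PR actually adds.

    import torch


    class StaticKVCache:
        """Toy illustration of a pre-allocated (static-shape) KV cache.

        Sketch of the --reuse_cache idea only; not the KVCache class added by this PR.
        """

        def __init__(self, batch, heads, max_len, head_dim, dtype=torch.bfloat16, device="cpu"):
            # Allocate full-length buffers once so tensor shapes stay static across
            # decode steps, which is what lets HPU graphs be reused (pass device="hpu" on Gaudi).
            self.key = torch.zeros(batch, heads, max_len, head_dim, dtype=dtype, device=device)
            self.value = torch.zeros_like(self.key)

        def update(self, key_states, value_states, token_idx):
            # Write the new tokens' K/V into the buffers in place instead of
            # concatenating ever-growing tensors every step.
            seq_len = key_states.shape[-2]
            self.key[:, :, token_idx : token_idx + seq_len] = key_states
            self.value[:, :, token_idx : token_idx + seq_len] = value_states
            return self.key, self.value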

Implementation borrowed from the Qwen2 PRs:
#1087
#1033
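The flash-attention wiring follows the same pattern as those Qwen2 PRs. Below is a minimal sketch of the fused versus eager attention paths, assuming the FusedSDPA kernel and the sdp_kernel context manager exposed by habana_frameworks; it illustrates the pattern only and is not the exact code in this PR.

    import torch

    try:
        # Gaudi fused scaled-dot-product-attention kernel (assumed import path,
        # as used in the referenced Qwen2 PRs).
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
        import habana_frameworks.torch.hpu as ht
    except ImportError:  # running off-Gaudi: only the eager path is available
        FusedSDPA = None
        ht = None


    def sdpa_or_eager(query, key, value, attention_mask, use_flash_attention, flash_attention_recompute):
        """Hypothetical helper; query/key/value are [batch, heads, seq_len, head_dim]."""
        if use_flash_attention and FusedSDPA is not None:
            # Fused path: one kernel computes softmax(QK^T / sqrt(d)) @ V.
            # enable_recompute trades memory for recomputation during prefill,
            # which is what --flash_attention_recompute toggles.
            with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                return FusedSDPA.apply(query, key, value, attention_mask, 0.0, False, None)

        # Eager fallback: explicit matmul + softmax, as in upstream Starcoder2.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) / (query.size(-1) ** 0.5)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        return torch.matmul(attn_weights, value)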

Validation on optimum-habana text-generation inference:

With this PR:
python run_generation.py --model_name_or_path "bigcode/starcoder2-3b" --bf16 --batch_size 100 --use_hpu_graphs --do_sample --prompt "def print_hello_world():" --max_new_tokens 550 --max_input_tokens 30 --use_kv_cache --reuse_cache --use_flash_attention --flash_attention_recompute

07/22/2024 16:32:50 - INFO - __main__ - device: hpu, n_hpu: 0, bf16: True
07/22/2024 16:32:50 - INFO - __main__ - Model initialization took 8.926s
07/22/2024 16:32:50 - INFO - __main__ - Graph compilation...
Warming up
07/22/2024 16:32:56 - INFO - __main__ - Time to first token = 1038.1485800025985ms
Warming up
07/22/2024 16:33:01 - INFO - __main__ - Time to first token = 107.46545001165941ms
Warming up
07/22/2024 16:33:05 - INFO - __main__ - Time to first token = 64.5028020371683ms
07/22/2024 16:33:05 - INFO - __main__ - Running generate...
07/22/2024 16:33:10 - INFO - __main__ - Time to first token = 63.07519105030224ms
07/22/2024 16:33:14 - INFO - __main__ - Time to first token = 63.960200001019984ms
07/22/2024 16:33:18 - INFO - __main__ - Time to first token = 65.7379609765485ms
07/22/2024 16:33:23 - INFO - __main__ - Time to first token = 65.4228410567157ms
07/22/2024 16:33:27 - INFO - __main__ - Time to first token = 64.82256192248315ms

input 1: ('def print_hello_world():',)
output 1: ('def print_hello_world():\n    """Print Hello World\n    """\n    print("hello world!")\nfrom setuptools import setup, find_packages\n\nsetup(\n    name="HelloWorld",\n    version="0.1.0",\n    python_requires=">3.7",\n
description="This is Hello World.",\n    packages=find_packages()\n)\nfrom setuptools import find_packages, setup, Extension\n\nextensions = [\n    Extension(\n        name = "say", \n        sources = ["scripts/say/say.c"],\n    ),\n]\n\nsetup(\n    name = "say",\n    version = "1.0",\n    description = "say in python",\n    author = "",\n    license = "MIT",\n    packages = find_packages(),\n    ext_package = "saylib",\n    ext_modules = extensions,\n)\npackage
main\n\nimport "fmt"\n\nfunc main() {\n    println("hello world!")\n}\nfrom setuptools import setup, find_packages\n\nsetup(\n    name="saylib",\n    version="2.0.0",\n    description="Say Lib",\n    license="MIT",\n    packages=find_packages(),\n    package_data={\n       \'saylib\': [\'say.txt\']\n    },\n    include_package_data=True,\n)\nfrom setuptools import setup, Extension\n\ndef main():\n    setup(\n        name="hello_world",\n        version="1.0",\n        ext_modules=[\n            Extension("say",\n                [\'say.c\']\n            ),\n        ],\n    )\n\nmain()\n# python - setuptools________\n\n[![python-hello-world](https://github.com/shin1ogari/python-hello-world/workflows/python-hello-world/badge.svg)](PYTHON-HELLO-WORLD)\n\nhttps://shin1ogari.hatenablog.com/entry/2021/06/17/154143\n\n## How to Install\n\n```\ngit clone https://github.com/shin1ogari/python-hello-world.git\ncd python-hello-world\npython3 setup.py sdist\npython3 -m pip install dist/HelloWorld-0.1.0.tar.gz\n```\n/**\n * @author \n */\n\n/*This function will give the list of tasks in 2 lists, one completed and one not completed\n * @todo : Sort the list on basis of due date \n */\n\n\nconst tasks = [\n    {\n        taskName: "',)

Stats:
---------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 12568.517464337727 tokens/second
Number of HPU graphs                = 19
Memory allocated                    = 9.11 GB
Max memory allocated                = 10.24 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 14.794863908027764 seconds
---------------------------------------------------------------------------------------------------------------

Without this PR:
python run_generation.py --model_name_or_path "bigcode/starcoder2-3b" --bf16 --batch_size 100 --use_hpu_graphs --do_sample --prompt "def print_hello_world():" --max_new_tokens 550 --max_input_tokens 30 --use_kv_cache --use_flash_attention --flash_attention_recompute

07/22/2024 16:15:16 - INFO - __main__ - device: hpu, n_hpu: 0, bf16: True
07/22/2024 16:15:16 - INFO - __main__ - Model initialization took 8.410s
07/22/2024 16:15:16 - INFO - __main__ - Graph compilation...
Warming up
07/22/2024 16:15:51 - INFO - __main__ - Time to first token = 28370.036993990652ms
Warming up
07/22/2024 16:15:57 - INFO - __main__ - Time to first token = 1484.4204590190202ms
Warming up
07/22/2024 16:16:03 - INFO - __main__ - Time to first token = 1472.5127129931934ms
07/22/2024 16:16:03 - INFO - __main__ - Running generate...
07/22/2024 16:16:09 - INFO - __main__ - Time to first token = 1471.8866730108857ms
07/22/2024 16:16:15 - INFO - __main__ - Time to first token = 1471.6727570048533ms
07/22/2024 16:16:20 - INFO - __main__ - Time to first token = 1476.9293259596452ms
07/22/2024 16:16:26 - INFO - __main__ - Time to first token = 1473.4082049690187ms
07/22/2024 16:16:32 - INFO - __main__ - Time to first token = 1473.347980005201ms

input 1: ('def print_hello_world():',)                                                                                                                                                                                               
output 1: ('def print_hello_world():\n    print("Hello, world!")\n\nprint_hello_world() # ____\n\ndef make_square_area(size):\n    return size ** 2\n\nprint(make_square_area(5))\nprint(make_square_area(10))\nprint(make_square_area(13)
)\n\n\n\n# def _________(____):\n#     ____ ____\n#     ____ ____\n#    ...\n#     return ______\n\n# 1. ______ 10____ ____ ________ ______ ____ (________ ____)\n# 2. ______ ________ ____ (______ ____)\n# 3. 1____ 100______ ______ __ __ ____ ________ return____ ____\n\n\n#!/bin/python3\n\nimport math\nimport os\nimport random\nimport re\nimport sys\n\n\n\n#\n# Complete the \'addIntegers\' function below.\n#\n# The function is expected to return an INTEGER.\n# The function accepts following parameters:\n#  1. INTEGER A\n#  2. INTEGER B\n#\n\n# 10 20\n# 10 20   \n# 10 + 20 = 30\n# 10 + 20 = 30\n\n#  def addIntegers(A, B):\n#     for i in range(0, 2):\n#         for j in range(0, 2):\n#
   print(f\'{A[i]} + {B[j]} = {A[i] + B[j]}\')\n\ndef addIntegers(A, B):\n    result = []\n    for i in range(0, A):\n        for j in range(0, B):\n            sum = A[i] + B[j]\n            result.append(sum)\n    return result \n\n\n\n\nif __name__ == \'__main__\':\n    fptr = open(os.environ[\'OUTPUT_PATH\'], \'w\')\n\n    first_multiple_input = input().rstrip().split()\n\n    A = int(first_multiple_input[0])\n\n    B = int(first_multiple_input[1])\n    A_a = []\n    B_a = []\n    for i in range(0, A):\n        A_a.append(i)\n    for i in range(0, B):\n        B_a.append(i)\n    res = addIntegers',)

Stats:
-------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 9344.66483451947 tokens/second
Number of HPU graphs                = 14
Memory allocated                    = 24.6 GB
Max memory allocated                = 34.31 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 46.98551622900413 seconds
-------------------------------------------------------------------------------------------------------------

Review threads (outdated, resolved) on optimum/habana/transformers/models/starcoder2/__init__.py and optimum/habana/transformers/models/__init__.py

        super().__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)
Collaborator

keep an eye on: #1148

If this merges first, you can refactor like this

@libinta added the review wip and synapse1.17 (PR that should be available along with Synapse 1.17 but has no dependency on Synapse 1.17 content) labels on Jul 24, 2024
@vidyasiv
Contributor

@abhatkal, you might want to check/update test_text_generation_example.py for starcoder2 to showcase the improvement with this PR.

@abhatkal
Contributor Author

@ssarkar2 @abhilash1910 I observed that the lines below consistently result in bad outputs.

        if (
            lazy_mode
            and not self.training
            and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
        ):
            htcore.mark_step()

Commenting them out gives better results without much difference in throughput. However, the TTFT increases for the first warmup step. Is there a better way around this?

Example run:

python run_generation.py --model_name_or_path bigcode/starcoder2-3b --batch_size 1 --use_hpu_graphs --do_sample --prompt "def print_hello_world():" --use_kv_cache --reuse_cache --use_flash_attention --flash_attention_recompute --bf16

Results after commenting the above lines:

Warming up
07/26/2024 09:12:02 - INFO - __main__ - Time to first token = 2216.9843481387943ms
Warming up
07/26/2024 09:12:02 - INFO - __main__ - Time to first token = 65.63230406027287ms
Warming up
07/26/2024 09:12:03 - INFO - __main__ - Time to first token = 9.99257992953062ms
07/26/2024 09:12:03 - INFO - __main__ - Running generate...
07/26/2024 09:12:05 - INFO - __main__ - Time to first token = 8.498925948515534ms

input 1: ('def print_hello_world():',)
output 1: ('def print_hello_world():\n    print("Hello world")\n\n\nif __name__ == "__main__":\n    print_hello_world()/Day16/ Day 16 - Sets and Modules/ Sets and Modules/set.py\nmyset = set()\nprint(myset)\nmyset = set([1,2,3])\nprint(myset)\nmyset = set((1,2,3))\nprint(myset)\nmyset = set({1',)

Stats:
--------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 246.3352487081534 tokens/second
Number of HPU graphs                = 17
Memory allocated                    = 6.17 GB
Max memory allocated                = 6.2 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 5.175242085941136 seconds
--------------------------------------------------------------------------------------------------------------

Results with the above lines intact:

Warming up
07/26/2024 09:09:29 - INFO - __main__ - Time to first token = 670.1515450840816ms
Warming up
07/26/2024 09:09:29 - INFO - __main__ - Time to first token = 53.274671896360815ms
Warming up
07/26/2024 09:09:30 - INFO - __main__ - Time to first token = 8.514783112332225ms
07/26/2024 09:09:30 - INFO - __main__ - Running generate...
07/26/2024 09:09:30 - INFO - __main__ - Time to first token = 7.96029600314796ms

Input/outputs:
input 1: ('def print_hello_world():',)
output 1: ('def print_hello_world():\n   \n.8\n       )- = `-9_ 7-)\'\',567_s. `9 =a0: "/(7_6_s_ 1 [ the\n',)

Stats:
--------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 251.1309268210929 tokens/second
Number of HPU graphs                = 19
Memory allocated                    = 6.17 GB
Max memory allocated                = 6.18 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 2.4475005640415475 seconds
-------------------------------------------------------------------------------------------------------------- 

@vidyasiv
Contributor

vidyasiv commented Aug 2, 2024

@abhatkal, please resolve the merge conflicts.
@ssarkar2 @abhilash1910, please respond to Amit's query.

@abhilash1910
Contributor

(Quoting @abhatkal's comment above about the htcore.mark_step() lines and the two example runs in full.)

This was discussed internally, but from what I gather, htcore.mark_step() cannot be avoided, as it is roughly equivalent to optimizer.step() in stock PyTorch. I think we can do away with checking for distributed initialization to save some cycles; without the mark_step(), the logits are expected to be incorrect. Considering the HPU-graph use case, lazy mode is traditionally the default, and since this path is inference-only, I think the entire condition can be removed so that mark_step() runs by default rather than branch checking. Open to suggestions.
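If I understand that correctly, the simplified guard would look roughly like the hypothetical helper below (a sketch only, not a proposed diff): keep the graph break in lazy-mode inference and drop the torch.distributed branch checking.

    import habana_frameworks.torch.core as htcore


    def maybe_mark_step(lazy_mode: bool, training: bool) -> None:
        # Sketch of the simplified guard: no torch.distributed.is_initialized() /
        # get_world_size() check; in lazy-mode inference we just cut the graph here.
        if lazy_mode and not training:
            htcore.mark_step()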

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@libinta added the run-test (Run CI for PRs from external contributors) label and removed the review wip label on Aug 6, 2024
@regisss regisss merged commit 13b6452 into huggingface:main Aug 6, 2024
3 checks passed
regisss added a commit that referenced this pull request Aug 7, 2024
Co-authored-by: Colabrese <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>
Co-authored-by: Sayantan Sarkar <[email protected]>
Co-authored-by: regisss <[email protected]>
@regisss
Collaborator

regisss commented Aug 7, 2024

There seems to be an issue with this PR. Running the CI test with

GAUDI2_CI=1 pytest tests/test_text_generation_example.py -v -s -k "starcoder2"

before the PR the output is

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework for supervised learning on large scale datasets.t_output_dir = os.path.join(output_dir, "output_dir")\nos.makedirs(output_dir, exist_ok=True)\nos.makedirs(output_dir_with_model_and_data, exist_ok=True)\nos.makedirs(output_dir_with_model_and_data_and_output_dir, exist_ok=True)\nos',)

and after it is

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework for super super fast super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super super',)

Note that this test doesn't use flash attention.

@abhatkal
Contributor Author

abhatkal commented Aug 7, 2024

Weirdly, just commenting out these lines gives back the right output:

        if (
            lazy_mode
            and not self.training
            and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
        ):
            htcore.mark_step()

@ssarkar2 @abhilash1910 It seems like htcore.mark_step() is being called twice: once at line 616 (which is fine) and then at line 624 (which is affecting the output accuracy).
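For context, a rough schematic of the two call sites is sketched below with a toy module; it is hypothetical and not the actual file contents, and the "line 616" / "line 624" comments just map to the calls mentioned above.

    import torch
    import torch.distributed
    import habana_frameworks.torch.core as htcore


    class TinyDecoderStack(torch.nn.Module):
        """Toy stand-in for the decoder stack, only to show where the two
        htcore.mark_step() calls sit (hypothetical; not the PR's code)."""

        def __init__(self, num_layers=2, hidden=8):
            super().__init__()
            self.layers = torch.nn.ModuleList([torch.nn.Linear(hidden, hidden) for _ in range(num_layers)])

        def forward(self, hidden_states, lazy_mode=True):
            if lazy_mode:
                htcore.mark_step()  # ~line 616: one graph break before the layer loop (fine)
            for decoder_layer in self.layers:
                hidden_states = decoder_layer(hidden_states)
                if (
                    lazy_mode
                    and not self.training
                    and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
                ):
                    htcore.mark_step()  # ~line 624: per-layer graph break that appears to hurt accuracy
            return hidden_states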

GAUDI2_CI=1 pytest tests/test_text_generation_example.py -v -s -k "starcoder2"

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework for supervised learning on graphs and beyond.\nt_name_list = []\nfor i in range(len(name_list)):\n    first_name_list.append(name_list[i][0])\nfirst_name_list\nlast_name_list = []\nfor i in range(len(name_list)):\n    last_name_list.append(name_list[i][1])\nlast_name_list',)

Stats:
---------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 257.89395843854567 tokens/second
Number of HPU graphs                = 14
Memory allocated                    = 6.19 GB
Max memory allocated                = 6.2 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 4.361379731912166 seconds
---------------------------------------------------------------------------------------------------------------

PASSED

========================================================= warnings summary ==========================================================
../../../../usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py:124
  /usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================== 1 passed, 46 deselected, 1 warning in 23.28s ============================================

@abhatkal
Contributor Author

abhatkal commented Aug 23, 2024

@abhilash1910 @ssarkar2 It seems like the multiple htcore.mark_step() calls are indeed an issue. The extra call was removed for Mixtral in this recent commit: d427f1f

Can I go ahead and do the same for starcoder2 as well? Please refer to my previous comment for more details

@abhilash1910
Contributor

(Quoting @abhatkal's comment above.)

Yes @abhatkal, this should be a new PR to address the issue. Best to remove the additional htcore.mark_step().

@regisss regisss mentioned this pull request Oct 8, 2024
@regisss
Collaborator

regisss commented Oct 8, 2024

Fixed in #1405

Labels
run-test (Run CI for PRs from external contributors), synapse1.17 (PR that should be available along with Synapse 1.17 but has no dependency on Synapse 1.17 content)