Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] overhaul memory profiling and fix backward compatibility #10511

Merged
merged 23 commits into from
Dec 16, 2024

Conversation

youkaichao
Copy link
Member

@youkaichao youkaichao commented Nov 21, 2024

fixes #10451 , and clearly explain the memory classification and the procedure.

I also added the initial pytorch memory, to be aligned with the pytorch memory profiler.

the profiling procedure is extracted into vllm/utils , so that we can use it later in v1 too.

Signed-off-by: youkaichao <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@mgoin mgoin self-requested a review November 21, 2024 02:58
vllm/utils.py Outdated Show resolved Hide resolved
vllm/worker/worker.py Outdated Show resolved Hide resolved
@DarkLight1337
Copy link
Member

cc @joerunde

vllm/utils.py Outdated Show resolved Hide resolved
Copy link

mergify bot commented Nov 23, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @youkaichao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@JaheimLee
Copy link

any progress? I do need to run multi-instance with one GPU.

@Pydataman
Copy link

any progress? I do need to run multi-instance with one GPU.

is this function available?

@youkaichao
Copy link
Member Author

let me finish it this week.

@joerunde
Copy link
Collaborator

joerunde commented Dec 4, 2024

Thanks for taking this on @youkaichao!

I think the docs for the --gpu-memory-utilization flag should also be updated in this PR to reflect the changes once this is working properly

@mergify mergify bot removed the needs-rebase label Dec 14, 2024
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
"PyTorch activation peak memory\t"
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB\n"
"available_kv_cache_memory\t"
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB\n")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some suggestions on making this a bit nicer:

  1. The log would be much easier to parse if the numbers were aligned on the same column, right now they're all over the place
  2. The descriptions are a mix of plain words and variable names, for logs maybe we should just use words. available_kv_cache_memory: -> KV Cache Size: etc.
  3. "Non torch memory" is the one item on this list that I think might not be easily understood by somebody reading the logs. Maybe calling it something a little more generic like "Memory overhead" would be less distracting

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of the newlines take too much space in the logs. IMO it would be more simple to keep the same single-line comma-separated result as before

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed in #10511 (comment) , PTAL

@joerunde
Copy link
Collaborator

@youkaichao Any chance you can add in a quick test for the profiling context manager itself? As an example the one I wrote up here was very simple to do: https://github.com/vllm-project/vllm/pull/11120/files#diff-33c13e0b177bacd2f02e29bcb8aea5b49e7ce34901fd8f41fefb65defba1bd33R277-R312

@joerunde
Copy link
Collaborator

@youkaichao 🤔🤔🤔 Loading facebook/opt-125 twice in the same process on an A100 measures a negative non_torch_memory value:

from vllm import LLM
m1 = LLM("facebook/opt-125m", gpu_memory_utilization=0.25)
m2 = LLM("facebook/opt-125m", gpu_memory_utilization=0.25)
...
INFO 12-16 15:20:32 worker.py:243] non_torch_memory	-0.02GiB

Might not be super important to fix- I think the main use case to unblock here is multi-process vllm serving. But it is interesting, I can't immediately see why that would happen

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work Kaichao, I appreciate the walkthrough example in memory_profiling. This passed my local usage and I didn't see the issue Joe saw, or think it is a serious issue. My only nit is on adding all the newlines to the log, I think it was fine as comma-separated list

"PyTorch activation peak memory\t"
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB\n"
"available_kv_cache_memory\t"
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB\n")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of the newlines take too much space in the logs. IMO it would be more simple to keep the same single-line comma-separated result as before

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 16, 2024
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@youkaichao
Copy link
Member Author

@mgoin changed the logging to be:

INFO 12-16 10:14:30 worker.py:241] Memory profiling takes 1.01 seconds
INFO 12-16 10:14:30 worker.py:241] the current vLLM instance can use total_gpu_memory (79.22GiB) x gpu_memory_utilization (0.90) = 71.29GiB
INFO 12-16 10:14:30 worker.py:241] model weights take 14.96GiB; non_torch_memory takes 0.18GiB; PyTorch activation peak memory takes 1.26GiB; the rest of the memory reserved for KV Cache is 54.90GiB.

Let me know if you have further ideas on how to improve the readability.

@youkaichao
Copy link
Member Author

@youkaichao 🤔🤔🤔 Loading facebook/opt-125 twice in the same process on an A100 measures a negative non_torch_memory value:

from vllm import LLM
m1 = LLM("facebook/opt-125m", gpu_memory_utilization=0.25)
m2 = LLM("facebook/opt-125m", gpu_memory_utilization=0.25)
...
INFO 12-16 15:20:32 worker.py:243] non_torch_memory	-0.02GiB

Might not be super important to fix- I think the main use case to unblock here is multi-process vllm serving. But it is interesting, I can't immediately see why that would happen

@joerunde this is because PyTorch's internal memory fragmentation. If PyTorch allocates 2MiB from cuda, and allocate 1MiB only, then this 1 MiB will be accounted as non-torch memory. And when you run it the next time, maybe you allocate another 1 MiB, and the internal memory fragmentation reduces.

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@youkaichao
Copy link
Member Author

@youkaichao Any chance you can add in a quick test for the profiling context manager itself? As an example the one I wrote up here was very simple to do: #11120 (files)

@joerunde that's a great idea! I added it now, PTAL.

Copy link
Collaborator

@joerunde joerunde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty good to me! Thanks for looking at this so thoroughly

@youkaichao
Copy link
Member Author

errors are unrelated, merging

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Breaking Change in gpu_memory_utilization Behavior in vLLM 0.6.4
7 participants