vLLM for Online DPO #2558
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
We get different results with vLLM, probably linked to the sampling params. Investigating.
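For reference, a likely source of divergence is that vLLM's SamplingParams defaults don't match the trainer's transformers generation config (top_k, top_p, etc.). A minimal sketch of aligning the two, assuming the trainer holds a GenerationConfig (the helper name is hypothetical):

```python
from vllm import SamplingParams

def sampling_params_from_generation_config(gen_config, max_new_tokens):
    # Mirror the HF generation config so both backends sample from the same distribution.
    return SamplingParams(
        temperature=gen_config.temperature if gen_config.do_sample else 0.0,  # 0.0 = greedy in vLLM
        top_p=gen_config.top_p if gen_config.top_p is not None else 1.0,
        top_k=gen_config.top_k if gen_config.top_k is not None else -1,  # -1 disables top-k in vLLM
        max_tokens=max_new_tokens,
    )
```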
CI fails because the latest transformers version, released yesterday, uses Python 3.10+ syntax (…).
max_length (`int`, *optional*, defaults to `256`):
    Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
    sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
    possible.
To avoid OOM for long prompts
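A minimal sketch of the left-truncation described in the docstring above (a hypothetical helper, not the exact code in this PR):

```python
def left_truncate(prompt_ids, completion_ids, max_length=256):
    # Keep at most `max_length` tokens; drop the leftmost (prompt) tokens first
    # so that as much of the completion as possible is preserved.
    sequence = prompt_ids + completion_ids
    return sequence[-max_length:]
```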
trl/trainer/online_dpo_trainer.py
Outdated
# However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded
# after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough
# space for them. Setting gpu_memory_utilization to 0.6 seems to work well in practice.
self.llm = LLM(model=model.name_or_path, gpu_memory_utilization=0.55, dtype=torch.float32)
`dtype` and `gpu_memory_utilization` are hardcoded for now, but we can still expose them as args in the future. See here why using `torch.float32` is important.
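If they are exposed later, it could look roughly like this (the argument names and helper are hypothetical, not part of this PR):

```python
from vllm import LLM

def build_generation_llm(model_name_or_path, gpu_memory_utilization=0.55, dtype="float32"):
    # float32 matters for matching the trained model's log probabilities (see the linked issue);
    # gpu_memory_utilization leaves headroom for the optimizer states that land on the GPU after the first step.
    return LLM(
        model=model_name_or_path,
        gpu_memory_utilization=gpu_memory_utilization,
        dtype=dtype,
    )
```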
Awesome addition. If I understand correctly, there are now two instances of the model, so this will only work with DDP and small models where memory capacity is not an issue?
Correct: one trained model and one generation model. Note that the weights of these two models are always the same. I wonder if we could share them instead of duplicating them? @hmellor any idea?
Yes, parallelism (not supported yet) should make it possible to relax this constraint.
vllm-project/vllm#10353 is probably the closest we can currently get to directly accessing vLLM's model, but this PR is expected to be superseded by work done in the V1 engine re-architecture that is ongoing (planned beta-release any day now, on by default by end of January). vllm-project/vllm#12084 introduces some RLHF features to vLLM, but it follows the OpenRLHF model where training and inference processes live on different GPUs.
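Until something like that lands, a rough sketch of the kind of per-step weight sync the trainer could do is below. The attribute path into vLLM's model is internal and has changed between releases, so treat it as an assumption rather than a stable recipe:

```python
def sync_weights_into_vllm(trained_model, llm):
    # Push the trained model's current weights into the colocated vLLM engine
    # so generation always uses up-to-date parameters.
    state_dict = trained_model.state_dict()
    vllm_model = llm.llm_engine.model_executor.driver_worker.model_runner.model  # internal, version-dependent
    vllm_model.load_weights(state_dict.items())
```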
What does this PR do?
Use vLLM for generation. 2.2x faster.
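For context, the generation path boils down to something like the following (a simplified sketch with an example model name, not the exact trainer code):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct", gpu_memory_utilization=0.55, dtype="float32")
sampling_params = SamplingParams(temperature=0.9, max_tokens=256, n=2)  # 2 completions per prompt for DPO pairs
outputs = llm.generate(["The capital of France is"], sampling_params)
completions = [completion.text for completion in outputs[0].outputs]
```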
Demo:
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.