Commit 0f07cd6

πŸ™ Add adapters to vllm inference

alekseymalakhov11 authored Oct 15, 2024
2 parents 45f259b + b4f6cb7 commit 0f07cd6

Showing 6 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -3,7 +3,7 @@ name = "turbo-alignment"
 packages = [
     { include = "turbo_alignment" },
 ]
-version = "0.0.2"
+version = "0.0.3"
 description = "turbo-alignment repository"
 authors = ["T Mega Alignment Team <[email protected]>" ]
 readme = "README.md"
2 changes: 1 addition & 1 deletion turbo_alignment/cherry_picks/multimodal.py

@@ -1,10 +1,10 @@
 from typing import Iterable
 
 import torch
-import wandb
 from PIL import Image
 from transformers import PreTrainedTokenizerBase
 
+import wandb
 from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
 from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset
 from turbo_alignment.generators.multimodal import MultimodalGenerator
2 changes: 1 addition & 1 deletion turbo_alignment/common/logging/weights_and_biases.py

@@ -1,9 +1,9 @@
 from typing import Any
 
-import wandb
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run
 
+import wandb
 from turbo_alignment.settings.logging.weights_and_biases import WandbSettings
2 changes: 1 addition & 1 deletion turbo_alignment/common/tf/callbacks/logging.py

@@ -3,7 +3,6 @@
 
 import numpy as np
 import pandas as pd
-import wandb
 from clearml import Task
 from transformers import (
     TrainerCallback,
@@ -14,6 +13,7 @@
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run
 
+import wandb
 from turbo_alignment.common.logging import get_project_logger
 
 logger = get_project_logger()
4 changes: 4 additions & 0 deletions turbo_alignment/generators/vllm_chat.py

@@ -3,6 +3,7 @@
 import torch
 from transformers import PreTrainedTokenizerBase
 from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest
 
 from turbo_alignment.dataset.chat import ChatDatasetRecord
 from turbo_alignment.generators.base import BaseGenerator
@@ -23,6 +24,7 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
         return_logits: bool = False,
+        lora_request: LoRARequest | None = None,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -51,6 +53,7 @@ def __init__(
             max_tokens=transformers_settings.max_new_tokens,
             **beam_search_params,
         )
+        self._lora_request = lora_request
 
         self._return_logits = return_logits
@@ -62,6 +65,7 @@ def _generate_from_batch(
             prompts=None,
             prompt_token_ids=input_ids,
             sampling_params=self._sampling_params,
+            lora_request=self._lora_request,
         )
 
         outputs = []
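
The hunks above follow vLLM's standard multi-LoRA flow: the engine must be built with enable_lora=True, and each generate call may then carry a LoRARequest naming the adapter to apply. A minimal standalone sketch of that flow (the model name and adapter path below are placeholders, not values from this commit):

# Minimal vLLM LoRA inference sketch; model name and adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# The engine accepts LoRA requests only when built with enable_lora=True.
llm = LLM(model='meta-llama/Llama-2-7b-hf', enable_lora=True)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# LoRARequest takes a name, a unique integer id, and the local adapter path.
lora_request = LoRARequest('adapter', 1, '/path/to/lora_adapter')

# lora_request=None runs the base weights; passing a request applies the adapter.
outputs = llm.generate(['Hello!'], sampling_params, lora_request=lora_request)
print(outputs[0].outputs[0].text)
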
16 changes: 15 additions & 1 deletion turbo_alignment/pipelines/inference/chat.py

@@ -8,6 +8,7 @@
 from turbo_alignment.generators.base import BaseGenerator
 from turbo_alignment.generators.chat import ChatGenerator
 from turbo_alignment.pipelines.inference.base import BaseInferenceStrategy
+from turbo_alignment.settings.model import PreTrainedAdaptersModelSettings
 from turbo_alignment.settings.pipelines.inference.chat import (
     ChatInferenceExperimentSettings,
 )
@@ -27,14 +28,24 @@ def _get_single_inference_settings(
 
         if model_inference_settings.use_vllm:
             import vllm
+            from vllm.lora.request import LoRARequest
 
             from turbo_alignment.generators.vllm_chat import VLLMChatGenerator
 
+            lora_request: LoRARequest | None = None
+            enable_lora: bool = False
+
+            if isinstance(model_inference_settings.model_settings, PreTrainedAdaptersModelSettings):
+                lora_request = LoRARequest('adapter', 1, str(model_inference_settings.model_settings.adapter_path))
+                enable_lora = True
+
             model = vllm.LLM(
                 model=model_inference_settings.model_settings.model_path.absolute().as_posix(),
                 dtype='bfloat16',
                 tensor_parallel_size=model_inference_settings.tensor_parallel_size,
+                enable_lora=enable_lora,
             )
+
         else:
             model = load_model(model_inference_settings.model_settings, tokenizer)
             model = (
@@ -57,7 +68,10 @@ def _get_single_inference_settings(
                 accelerator=accelerator,
             )
             if not model_inference_settings.use_vllm
-            else VLLMChatGenerator(**generator_kwargs)
+            else VLLMChatGenerator(
+                **generator_kwargs,
+                lora_request=lora_request,
+            )
         )
 
         parameters_to_save = {
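
The pipeline change above enables LoRA only when the configured model settings are adapter settings. A simplified sketch of that dispatch, using stand-in dataclasses rather than the real turbo_alignment settings classes:

# Stand-ins for turbo_alignment's settings classes, for illustration only.
from dataclasses import dataclass
from pathlib import Path

from vllm.lora.request import LoRARequest


@dataclass
class PreTrainedModelSettings:
    model_path: Path


@dataclass
class PreTrainedAdaptersModelSettings(PreTrainedModelSettings):
    adapter_path: Path


def make_lora_request(settings: PreTrainedModelSettings) -> LoRARequest | None:
    # Adapter settings yield a LoRARequest; plain settings leave LoRA off.
    if isinstance(settings, PreTrainedAdaptersModelSettings):
        return LoRARequest('adapter', 1, str(settings.adapter_path))
    return None

Note that the commit hard-codes the adapter id to 1, so each vLLM engine serves a single adapter; vLLM distinguishes adapters by that integer id, so serving several in one process would require distinct ids.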
