diff --git a/pyproject.toml b/pyproject.toml
index 95596fa..febe392 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "turbo-alignment"
 packages = [
     { include = "turbo_alignment" },
 ]
-version = "0.0.2"
+version = "0.0.3"
 description = "turbo-alignment repository"
 authors = ["T Mega Alignment Team " ]
 readme = "README.md"
diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py
index ba1c536..80c6022 100755
--- a/turbo_alignment/cherry_picks/multimodal.py
+++ b/turbo_alignment/cherry_picks/multimodal.py
@@ -1,10 +1,10 @@
 from typing import Iterable

 import torch
-import wandb
 from PIL import Image
 from transformers import PreTrainedTokenizerBase

+import wandb
 from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
 from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset
 from turbo_alignment.generators.multimodal import MultimodalGenerator
diff --git a/turbo_alignment/common/logging/weights_and_biases.py b/turbo_alignment/common/logging/weights_and_biases.py
index 3b5bdfd..112e585 100755
--- a/turbo_alignment/common/logging/weights_and_biases.py
+++ b/turbo_alignment/common/logging/weights_and_biases.py
@@ -1,9 +1,9 @@
 from typing import Any

-import wandb
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run

+import wandb
 from turbo_alignment.settings.logging.weights_and_biases import WandbSettings
diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py
index 646e280..1ad484d 100755
--- a/turbo_alignment/common/tf/callbacks/logging.py
+++ b/turbo_alignment/common/tf/callbacks/logging.py
@@ -3,7 +3,6 @@

 import numpy as np
 import pandas as pd
-import wandb
 from clearml import Task
 from transformers import (
     TrainerCallback,
@@ -14,6 +13,7 @@
 from wandb.sdk.lib.disabled import RunDisabled
 from wandb.sdk.wandb_run import Run

+import wandb
 from turbo_alignment.common.logging import get_project_logger

 logger = get_project_logger()
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index ff52de3..5a89bb1 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -3,6 +3,7 @@
 import torch
 from transformers import PreTrainedTokenizerBase
 from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest

 from turbo_alignment.dataset.chat import ChatDatasetRecord
 from turbo_alignment.generators.base import BaseGenerator
@@ -23,6 +24,7 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
         return_logits: bool = False,
+        lora_request: LoRARequest | None = None,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -51,6 +53,7 @@ def __init__(
             max_tokens=transformers_settings.max_new_tokens,
             **beam_search_params,
         )
+        self._lora_request = lora_request

         self._return_logits = return_logits
@@ -62,6 +65,7 @@ def _generate_from_batch(
             prompts=None,
             prompt_token_ids=input_ids,
             sampling_params=self._sampling_params,
+            lora_request=self._lora_request,
         )

         outputs = []
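Usage sketch for the `lora_request` plumbing above: the generator stores the request at construction and forwards it unchanged to `LLM.generate`, so adapter selection happens per call. This is a minimal, hedged example against the vLLM API the diff targets; the model path, adapter path, and token ids are hypothetical placeholders, not values from this repository:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# The engine must be built with enable_lora=True before any adapter can be served.
llm = LLM(model='models/base', dtype='bfloat16', enable_lora=True)

# (name, unique integer id, local adapter path) -- the same shape used in the diff.
lora_request = LoRARequest('adapter', 1, 'adapters/my_adapter')

# Mirrors VLLMChatGenerator._generate_from_batch: passing lora_request=None
# here would generate with the base weights instead.
outputs = llm.generate(
    prompts=None,
    prompt_token_ids=[[1, 2, 3]],  # pre-tokenized inputs, as the generator supplies
    sampling_params=SamplingParams(max_tokens=32),
    lora_request=lora_request,
)
```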
diff --git a/turbo_alignment/pipelines/inference/chat.py b/turbo_alignment/pipelines/inference/chat.py
index a40df0f..93b975c 100755
--- a/turbo_alignment/pipelines/inference/chat.py
+++ b/turbo_alignment/pipelines/inference/chat.py
@@ -8,6 +8,7 @@
 from turbo_alignment.generators.base import BaseGenerator
 from turbo_alignment.generators.chat import ChatGenerator
 from turbo_alignment.pipelines.inference.base import BaseInferenceStrategy
+from turbo_alignment.settings.model import PreTrainedAdaptersModelSettings
 from turbo_alignment.settings.pipelines.inference.chat import (
     ChatInferenceExperimentSettings,
 )
@@ -27,14 +28,24 @@ def _get_single_inference_settings(

         if model_inference_settings.use_vllm:
             import vllm
+            from vllm.lora.request import LoRARequest

             from turbo_alignment.generators.vllm_chat import VLLMChatGenerator

+            lora_request: LoRARequest | None = None
+            enable_lora: bool = False
+
+            if isinstance(model_inference_settings.model_settings, PreTrainedAdaptersModelSettings):
+                lora_request = LoRARequest('adapter', 1, str(model_inference_settings.model_settings.adapter_path))
+                enable_lora = True
+
             model = vllm.LLM(
                 model=model_inference_settings.model_settings.model_path.absolute().as_posix(),
                 dtype='bfloat16',
                 tensor_parallel_size=model_inference_settings.tensor_parallel_size,
+                enable_lora=enable_lora,
             )
+
         else:
             model = load_model(model_inference_settings.model_settings, tokenizer)
             model = (
@@ -57,7 +68,10 @@ def _get_single_inference_settings(
                 accelerator=accelerator,
             )
             if not model_inference_settings.use_vllm
-            else VLLMChatGenerator(**generator_kwargs)
+            else VLLMChatGenerator(
+                **generator_kwargs,
+                lora_request=lora_request,
+            )
         )

         parameters_to_save = {
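The pipeline change above keys LoRA entirely off the settings type: plain model settings leave `lora_request` as `None` and the engine serves base weights, while `PreTrainedAdaptersModelSettings` flips `enable_lora` when the engine is constructed and builds the per-request adapter handle. Below is a hedged sketch of that dispatch as a standalone helper; `make_lora_request` is a hypothetical name for illustration, not a function added by this diff:

```python
from vllm.lora.request import LoRARequest

from turbo_alignment.settings.model import PreTrainedAdaptersModelSettings


def make_lora_request(model_settings) -> tuple[LoRARequest | None, bool]:
    """Return (lora_request, enable_lora) for building the vLLM engine."""
    if isinstance(model_settings, PreTrainedAdaptersModelSettings):
        # 'adapter' is an arbitrary label and 1 an arbitrary unique id;
        # vLLM loads the PEFT weights found at adapter_path per request.
        return LoRARequest('adapter', 1, str(model_settings.adapter_path)), True
    # Base-model inference: no adapter, and enable_lora stays False.
    return None, False
```

Keeping `enable_lora` separate from the request matters because it is an engine-construction flag in vLLM, while the adapter itself is chosen per `generate()` call, so one engine can serve both the base model and the adapter.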