diff --git a/README.md b/README.md
index de647d4b..a4087550 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ A repository for general-purpose RAG applications.
 
 ______________________________________________________________________
 
-[![Code Coverage](https://img.shields.io/badge/Coverage-68%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)
+[![Code Coverage](https://img.shields.io/badge/Coverage-70%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)
 
 Developer(s):
 
diff --git a/config/generator/vllm.yaml b/config/generator/vllm.yaml
index 43f702b5..8052e172 100644
--- a/config/generator/vllm.yaml
+++ b/config/generator/vllm.yaml
@@ -1,9 +1,12 @@
 name: vllm
 model: ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit
-max_model_len: 10_000
-gpu_memory_utilization: 0.95
 temperature: 0.0
 max_tokens: 256
 stream: true
+timeout: 60
 system_prompt: ${..language.system_prompt}
 prompt: ${..language.prompt}
+max_model_len: 10_000
+gpu_memory_utilization: 0.95
+server: null
+port: 8000
diff --git a/pyproject.toml b/pyproject.toml
index 41e2b580..e69d1271 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,6 +95,7 @@ filterwarnings = [
     "ignore::DeprecationWarning",
     "ignore::PendingDeprecationWarning",
     "ignore::ImportWarning",
+    "ignore::ResourceWarning",
 ]
 log_cli_level = "info"
 testpaths = [
diff --git a/src/ragger/generator.py b/src/ragger/generator.py
index bd0e3c93..a613d4d0 100644
--- a/src/ragger/generator.py
+++ b/src/ragger/generator.py
@@ -1,14 +1,14 @@
 """Generation of an answer from a query and a list of relevant documents."""
 
-import importlib.util
 import json
 import logging
 import os
+import subprocess
 import typing
+from time import sleep
 
 import torch
 from dotenv import load_dotenv
-from jinja2 import TemplateError
 from omegaconf import DictConfig
 from openai import OpenAI, Stream
 from openai.types.chat import (
@@ -18,23 +18,9 @@
 from openai.types.chat.completion_create_params import ResponseFormat
 from pydantic import ValidationError
 from pydantic_core import from_json
+from transformers import AutoTokenizer
 
 from .data_models import Document, GeneratedAnswer, Generator
-from .utils import clear_memory
-
-if importlib.util.find_spec("vllm") is not None:
-    from vllm import LLM, SamplingParams
-    from vllm.model_executor.guided_decoding.outlines_logits_processors import (
-        JSONLogitsProcessor,
-    )
-
-    try:
-        from vllm.model_executor.parallel_utils.parallel_state import (
-            destroy_model_parallel,
-        )
-    except ImportError:
-        from vllm.distributed.parallel_state import destroy_model_parallel
-
 
 load_dotenv()
 
@@ -54,9 +40,24 @@ def __init__(self, config: DictConfig) -> None:
         """
         super().__init__(config=config)
         logging.getLogger("httpx").setLevel(logging.CRITICAL)
-        api_key = os.environ[self.config.generator.api_key_variable_name]
+
+        if hasattr(config.generator, "api_key_variable_name"):
+            env_var_name = config.generator.api_key_variable_name
+            api_key = os.environ[env_var_name].strip('"')
+        else:
+            api_key = None
+
+        self.server: str | None
+        if hasattr(config.generator, "server"):
+            host = config.generator.server
+            if not host.startswith("http"):
+                host = f"http://{host}"
+            self.server = f"{host}:{config.generator.port}/v1"
+        else:
+            self.server = None
+
         self.client = OpenAI(
-            api_key=api_key.strip('"'), timeout=self.config.generator.timeout
+            base_url=self.server, api_key=api_key, timeout=self.config.generator.timeout
         )
 
     def generate(
@@ -91,6 +92,11 @@
                 ),
             ),
         ]
+
+        extra_body = dict()
+        if self.config.generator.name == "vllm":
extra_body["guided_json"] = GeneratedAnswer.model_json_schema() + model_output = self.client.chat.completions.create( messages=messages, model=self.config.generator.model, @@ -99,7 +105,9 @@ def generate( stream=self.config.generator.stream, stop=[""], response_format=ResponseFormat(type="json_object"), + extra_body=extra_body, ) + if isinstance(model_output, Stream): def streamer() -> typing.Generator[GeneratedAnswer, None, None]: @@ -108,7 +116,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]: for chunk in model_output: chunk_str = chunk.choices[0].delta.content if chunk_str is None: - break + continue generated_output += chunk_str try: generated_dict = from_json( @@ -174,7 +182,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]: return generated_obj -class VllmGenerator(Generator): +class VllmGenerator(OpenaiGenerator): """A generator that uses a vLLM model to generate answers.""" def __init__(self, config: DictConfig) -> None: @@ -184,110 +192,75 @@ def __init__(self, config: DictConfig) -> None: config: The Hydra configuration. """ - super().__init__(config=config) - - if not torch.cuda.is_available(): - raise RuntimeError( - "The `vLLMGenerator` requires a CUDA-compatible GPU to run. " - "Please ensure that a compatible GPU is available and try again." - ) - - # We need to remove the model from GPU memory before creating a new one - destroy_model_parallel() - clear_memory() - - self.model = LLM( - model=config.generator.model, - gpu_memory_utilization=config.generator.gpu_memory_utilization, - max_model_len=config.generator.max_model_len, - seed=config.random_seed, - tensor_parallel_size=torch.cuda.device_count(), - ) - self.tokenizer = self.model.get_tokenizer() - self.logits_processor = JSONLogitsProcessor( - schema=GeneratedAnswer, tokenizer=self.tokenizer, whitespace_pattern=r" ?" - ) + self.config = config + logging.getLogger("transformers").setLevel(logging.CRITICAL) + + # If an inference server isn't already running then start a new server in a + # background process and store the process ID + self.server_process: subprocess.Popen | None + if config.generator.server is None: + # We can only run the inference server if CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError( + "The `vLLMGenerator` requires a CUDA-compatible GPU to run. " + "Please ensure that a compatible GPU is available and try again." + ) + + config.generator.server = "0.0.0.0" + self.tokenizer = AutoTokenizer.from_pretrained(config.generator.model) + self.server_process = self.start_inference_server() + else: + self.server_process = None - def generate( - self, query: str, documents: list[Document] - ) -> GeneratedAnswer | typing.Generator[GeneratedAnswer, None, None]: - """Generate an answer from a query and relevant documents. + super().__init__(config=config) - Args: - query: - The query to answer. - documents: - The relevant documents. + def start_inference_server(self) -> subprocess.Popen: + """Start the vLLM inference server. Returns: - The generated answer. + The inference server process. """ - logger.info( - f"Generating answer for the query {query!r} and {len(documents):,} " - "documents..." 
-        )
-
-        system_prompt = self.config.generator.system_prompt
-        user_prompt = self.config.generator.prompt.format(
-            documents=json.dumps([document.model_dump() for document in documents]),
-            query=query,
-        )
-
-        chat_template_kwargs = dict(
-            chat_template=self.tokenizer.chat_template,
-            add_generation_prompt=True,
-            tokenize=False,
-        )
-        try:
-            prompt = self.tokenizer.apply_chat_template(
-                conversation=[
-                    dict(role="system", content=system_prompt),
-                    dict(role="user", content=user_prompt),
-                ],
-                **chat_template_kwargs,
-            )
-        except TemplateError:
-            prompt = self.tokenizer.apply_chat_template(
-                conversation=[
-                    dict(role="user", content=system_prompt + "\n\n" + user_prompt)
-                ],
-                **chat_template_kwargs,
-            )
-
-        sampling_params = SamplingParams(
-            max_tokens=self.config.generator.max_tokens,
-            temperature=self.config.generator.temperature,
-            stop=[""],
-            logits_processors=[self.logits_processor],
-        )
-
-        model_output = self.model.generate(
-            prompts=[prompt], sampling_params=sampling_params
+        logger.info("Starting vLLM server...")
+
+        # Start server using the vLLM entrypoint
+        process = subprocess.Popen(
+            args=[
+                "python",
+                "-m",
+                "vllm.entrypoints.openai.api_server",
+                "--model",
+                self.config.generator.model,
+                "--max-model-len",
+                str(self.config.generator.max_model_len),
+                "--gpu-memory-utilization",
+                str(self.config.generator.gpu_memory_utilization),
+                "--chat-template",
+                self.tokenizer.chat_template,
+                "--host",
+                self.config.generator.server,
+                "--port",
+                str(self.config.generator.port),
+            ],
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.PIPE,
         )
-        generated_output = model_output[0].outputs[0].text
-        try:
-            generated_dict = json.loads(generated_output)
-        except json.JSONDecodeError:
-            raise ValueError(
-                f"Could not decode JSON from model output: {generated_output}"
-            )
-
-        try:
-            generated_obj = GeneratedAnswer.model_validate(generated_dict)
-        except ValidationError:
-            raise ValueError(f"Could not validate model output: {generated_dict}")
+        # Wait for the server to start
+        stderr = process.stderr
+        assert stderr is not None
+        for seconds in range(self.config.generator.timeout):
+            update = stderr.readline().decode("utf-8")
+            if "Uvicorn running" in update:
+                logger.info(f"vLLM server started after {seconds} seconds.")
+                break
+            sleep(1)
+        else:
+            raise RuntimeError("vLLM server failed to start.")
 
-        logger.info(f"Generated answer: {generated_obj.answer!r}")
-        return generated_obj
+        return process
 
     def __del__(self) -> None:
-        """Clear the GPU memory used by the model, and remove the model itself."""
-        if hasattr(self, "model"):
-            del self.model
+        """Close down the vLLM server, if we started a new one."""
+        if self.server_process is not None:
+            self.server_process.kill()
         del self
-        try:
-            destroy_model_parallel()
-        except ImportError:
-            pass
-        clear_memory()
diff --git a/src/ragger/utils.py b/src/ragger/utils.py
index 5b4c4225..d58aad52 100644
--- a/src/ragger/utils.py
+++ b/src/ragger/utils.py
@@ -2,7 +2,9 @@
 
 import gc
 import importlib
+import os
 import re
+import sys
 from typing import Type
 
 import torch
@@ -103,15 +105,27 @@ def get_component_by_name(class_name: str, component_type: str) -> Type:
     Returns:
         The class.
+
+    Raises:
+        ValueError:
+            If the module or class cannot be found.
""" # Get the snake_case and PascalCase version of the class name full_class_name = f"{class_name}_{component_type}" name_pascal = snake_to_pascal(snake_string=full_class_name) - # Get the class from the module + # Get the module module_name = f"ragger.{component_type}" - module = importlib.import_module(name=module_name) - class_: Type = getattr(module, name_pascal) + try: + module = importlib.import_module(name=module_name) + except ModuleNotFoundError: + raise ValueError(f"Module {module_name!r}' not found.") + + # Get the class from the module + try: + class_: Type = getattr(module, name_pascal) + except AttributeError: + raise ValueError(f"Class {name_pascal!r} not found in module {module_name!r}.") return class_ @@ -138,3 +152,21 @@ def load_ragger_components(config: DictConfig) -> Components: class_name=config.generator.name, component_type="generator" ), ) + + +class HiddenPrints: + """Context manager which removes all terminal output.""" + + def __enter__(self): + """Enter the context manager.""" + self._original_stdout = sys.stdout + self._original_stderr = sys.stderr + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager.""" + sys.stdout.close() + sys.stderr.close() + sys.stdout = self._original_stdout + sys.stderr = self._original_stderr diff --git a/tests/conftest.py b/tests/conftest.py index 3f8d4f50..c6f5a65a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,13 +114,16 @@ def vllm_generator_params(system_prompt, prompt) -> typing.Generator[dict, None, yield dict( name="vllm", model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit", - max_model_len=10_000, - gpu_memory_utilization=0.95, temperature=0.0, max_tokens=128, stream=False, + timeout=60, system_prompt=system_prompt, prompt=prompt, + max_model_len=10_000, + gpu_memory_utilization=0.95, + server=None, + port=9999, ) diff --git a/tests/test_generator.py b/tests/test_generator.py index 2d93da1d..4ec769ff 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,6 +1,7 @@ """Unit tests for the `generator` module.""" import typing +from copy import deepcopy import pytest import torch @@ -71,7 +72,7 @@ def test_error_if_not_json(self, config, query, documents) -> None: config.generator.max_tokens = old_max_tokens def test_error_if_not_valid_types(self, config, query, documents) -> None: - """Test that the generator raises an error if the output is not JSON.""" + """Test that the generator raises an error if the JSON isn't valid.""" generator = OpenaiGenerator(config=config) bad_prompt = 'Inkludér kilderne i key\'en "kilder" i stedet for "sources".' 
         with pytest.raises(ValueError):
@@ -83,39 +84,33 @@ class TestVllmGenerator:
     """Tests for the `VllmGenerator` class."""
 
     @pytest.fixture(scope="class")
-    def config(self, vllm_generator_params) -> typing.Generator[DictConfig, None, None]:
+    def generator(
+        self, vllm_generator_params
+    ) -> typing.Generator[VllmGenerator, None, None]:
         """Initialise a configuration for testing."""
-        yield DictConfig(dict(random_seed=703, generator=vllm_generator_params))
+        config = DictConfig(dict(generator=vllm_generator_params))
+        yield VllmGenerator(config=config)
 
     def test_is_generator(self) -> None:
         """Test that the VllmGenerator is a Generator."""
         assert issubclass(VllmGenerator, Generator)
 
-    def test_initialisation(self, config) -> None:
+    def test_initialisation(self, generator) -> None:
         """Test that the generator is initialised correctly."""
-        generator = VllmGenerator(config=config)
         assert generator
-        del generator
 
-    def test_generate(self, config, query, documents) -> None:
+    def test_generate(self, generator, query, documents) -> None:
         """Test that the generator generates an answer."""
-        generator = VllmGenerator(config=config)
         answer = generator.generate(query=query, documents=documents)
         expected = GeneratedAnswer(answer="Uerop", sources=["2"])
         assert answer == expected
 
-    def test_error_if_not_json(self, config, query, documents) -> None:
+    def test_error_if_not_json(self, generator, query, documents) -> None:
         """Test that the generator raises an error if the output is not JSON."""
-        old_max_tokens = config.generator.max_tokens
-        config.generator.max_tokens = 1
-        generator = VllmGenerator(config=config)
+        old_config = generator.config
+        config_copy = deepcopy(old_config)
+        config_copy.generator.max_tokens = 1
+        generator.config = config_copy
         with pytest.raises(ValueError):
             generator.generate(query=query, documents=documents)
-        config.generator.max_tokens = old_max_tokens
-
-    def test_error_if_not_valid_types(self, config, query, documents) -> None:
-        """Test that the generator raises an error if the output is not JSON."""
-        generator = VllmGenerator(config=config)
-        bad_prompt = 'Inkludér kilderne i key\'en "kilder" i stedet for "sources".'
-        with pytest.raises(ValueError):
-            generator.generate(query=f"{query}\n{bad_prompt}", documents=documents)
+        generator.config = old_config
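Usage sketch (not part of the diff above): a minimal example of how the reworked VllmGenerator can be exercised after this change, mirroring the `generator` fixture in tests/test_generator.py. The concrete config values, prompts and Document fields below are illustrative assumptions rather than anything prescribed by the diff; running it requires either a CUDA-compatible GPU (so the class can launch its own vLLM server in a subprocess) or an already running vLLM server reachable via the `server`/`port` settings.

from omegaconf import DictConfig

from ragger.data_models import Document
from ragger.generator import VllmGenerator

# Illustrative config mirroring config/generator/vllm.yaml; the prompt template is
# assumed to contain the {documents} and {query} placeholders used by `generate`.
config = DictConfig(
    dict(
        generator=dict(
            name="vllm",
            model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit",
            temperature=0.0,
            max_tokens=256,
            stream=False,
            timeout=60,
            system_prompt="You are a helpful assistant.",  # placeholder
            prompt="Documents: {documents}\n\nQuestion: {query}",  # placeholder
            max_model_len=10_000,
            gpu_memory_utilization=0.95,
            server=None,  # None => start a local vLLM server in a background process
            port=8000,
        )
    )
)

generator = VllmGenerator(config=config)

# The Document fields (`id`, `text`) are assumed from the test fixtures.
documents = [Document(id="1", text="Ragger is a repository for general-purpose RAG applications.")]
answer = generator.generate(query="What is Ragger?", documents=documents)
print(answer.answer, answer.sources)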