Feat/use vllm server (#40)
* feat: Use vllm server

* fix: Use self.model

* debug

* debug

* feat: Try using guided_json

* fix: Use extra_body

* fix: Use json

* fix: Use model_json_schema

* chore: Include response_format

* chore: Logging

* chore: Remove logging

* debug

* fix: Do not break streaming if chunk_str is None

* debug

* feat: Spawn new vLLM server if not already running

* fix: Do not use api_key if running vLLM generator

* fix: vLLM config

* chore: Remove breakpoint

* debug

* debug

* fix: Set server after booting it

* debug

* debug

* fix: Add sleep after server start

* fix: Only require CUDA to start the vLLM inference server, not to use one

* fix: Only set `guided_json` if using vLLM

* tests: vLLM tests

* feat: Add more args to vLLM server

* fix: Typo

* debug

* fix: Up vLLM startup sleep time

* debug

* debug

* debug

* debug

* fix: Add port back in

* fix: Set up self.server in OpenaiGenerator correctly

* debug

* fix: Store config in VllmGenerator

* debug

* feat: Check manually if Uvicorn server has started

* feat: Block stderr when loading tokenizer

* debug

* refactor: Use HiddenPrints

* fix: Block transformers logging

* feat: Add --host back in

* debug

* fix: Add `del self` in `__del__`

* chore: Ignore ResourceWarning in pytest

* tests: Initialise the VllmGenerator fewer times in tests

* fix: Do not hardcode different ports

* tests: Use same VllmGenerator

* tests: Remove validity check test, as it is impossible with VllmGenerator

* tests: Remove random_seed from VllmGenerator config

* docs: Add comments

* fix: Raise ValueError in get_component_by_name if module or class don't exist

* docs: Update coverage badge

* chore: Re-instate pre-commit hook
saattrupdan authored May 22, 2024
1 parent 25e5fa3 commit 5dfeee2
Showing 7 changed files with 152 additions and 145 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@
A repository for general-purpose RAG applications.

______________________________________________________________________
[![Code Coverage](https://img.shields.io/badge/Coverage-68%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)
[![Code Coverage](https://img.shields.io/badge/Coverage-70%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)


Developer(s):
7 changes: 5 additions & 2 deletions config/generator/vllm.yaml
@@ -1,9 +1,12 @@
name: vllm
model: ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit
max_model_len: 10_000
gpu_memory_utilization: 0.95
temperature: 0.0
max_tokens: 256
stream: true
timeout: 60
system_prompt: ${..language.system_prompt}
prompt: ${..language.prompt}
max_model_len: 10_000
gpu_memory_utilization: 0.95
server: null
port: 8000
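
The new `server` and `port` keys are what let the generator either talk to an already-running vLLM server or spawn its own (when `server` is null). As a rough sketch of how these two fields end up as an OpenAI-compatible base URL, loosely mirroring the `OpenaiGenerator.__init__` change below (the hard-coded values here are illustrative, not the repo's exact logic):

```python
# Minimal sketch, assuming the YAML above is loaded via OmegaConf/Hydra.
# In the repo, VllmGenerator replaces a null `server` with "0.0.0.0" before
# spawning a local vLLM server; here we only show the URL construction.
from omegaconf import OmegaConf

generator_cfg = OmegaConf.create({"server": "0.0.0.0", "port": 8000})

host = generator_cfg.server
if not host.startswith("http"):
    host = f"http://{host}"
base_url = f"{host}:{generator_cfg.port}/v1"

print(base_url)  # http://0.0.0.0:8000/v1
```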
1 change: 1 addition & 0 deletions pyproject.toml
@@ -95,6 +95,7 @@ filterwarnings = [
"ignore::DeprecationWarning",
"ignore::PendingDeprecationWarning",
"ignore::ImportWarning",
"ignore::ResourceWarning",
]
log_cli_level = "info"
testpaths = [
207 changes: 90 additions & 117 deletions src/ragger/generator.py
@@ -1,14 +1,14 @@
"""Generation of an answer from a query and a list of relevant documents."""

import importlib.util
import json
import logging
import os
import subprocess
import typing
from time import sleep

import torch
from dotenv import load_dotenv
from jinja2 import TemplateError
from omegaconf import DictConfig
from openai import OpenAI, Stream
from openai.types.chat import (
@@ -18,23 +18,9 @@
from openai.types.chat.completion_create_params import ResponseFormat
from pydantic import ValidationError
from pydantic_core import from_json
from transformers import AutoTokenizer

from .data_models import Document, GeneratedAnswer, Generator
from .utils import clear_memory

if importlib.util.find_spec("vllm") is not None:
from vllm import LLM, SamplingParams
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor,
)

try:
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel,
)
except ImportError:
from vllm.distributed.parallel_state import destroy_model_parallel


load_dotenv()

@@ -54,9 +40,24 @@ def __init__(self, config: DictConfig) -> None:
"""
super().__init__(config=config)
logging.getLogger("httpx").setLevel(logging.CRITICAL)
api_key = os.environ[self.config.generator.api_key_variable_name]

if hasattr(config.generator, "api_key_variable_name"):
env_var_name = config.generator.api_key_variable_name
api_key = os.environ[env_var_name].strip('"')
else:
api_key = None

self.server: str | None
if hasattr(config.generator, "server"):
host = config.generator.server
if not host.startswith("http"):
host = f"http://{host}"
self.server = f"{host}:{config.generator.port}/v1"
else:
self.server = None

self.client = OpenAI(
api_key=api_key.strip('"'), timeout=self.config.generator.timeout
base_url=self.server, api_key=api_key, timeout=self.config.generator.timeout
)

def generate(
@@ -91,6 +92,11 @@ def generate(
),
),
]

extra_body = dict()
if self.config.generator.name == "vllm":
extra_body["guided_json"] = GeneratedAnswer.model_json_schema()

model_output = self.client.chat.completions.create(
messages=messages,
model=self.config.generator.model,
@@ -99,7 +105,9 @@
stream=self.config.generator.stream,
stop=["</answer>"],
response_format=ResponseFormat(type="json_object"),
extra_body=extra_body,
)

if isinstance(model_output, Stream):

def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
@@ -108,7 +116,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
for chunk in model_output:
chunk_str = chunk.choices[0].delta.content
if chunk_str is None:
break
continue
generated_output += chunk_str
try:
generated_dict = from_json(
@@ -174,7 +182,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
return generated_obj


class VllmGenerator(Generator):
class VllmGenerator(OpenaiGenerator):
"""A generator that uses a vLLM model to generate answers."""

def __init__(self, config: DictConfig) -> None:
@@ -184,110 +192,75 @@ def __init__(self, config: DictConfig) -> None:
config:
The Hydra configuration.
"""
super().__init__(config=config)

if not torch.cuda.is_available():
raise RuntimeError(
"The `vLLMGenerator` requires a CUDA-compatible GPU to run. "
"Please ensure that a compatible GPU is available and try again."
)

# We need to remove the model from GPU memory before creating a new one
destroy_model_parallel()
clear_memory()

self.model = LLM(
model=config.generator.model,
gpu_memory_utilization=config.generator.gpu_memory_utilization,
max_model_len=config.generator.max_model_len,
seed=config.random_seed,
tensor_parallel_size=torch.cuda.device_count(),
)
self.tokenizer = self.model.get_tokenizer()
self.logits_processor = JSONLogitsProcessor(
schema=GeneratedAnswer, tokenizer=self.tokenizer, whitespace_pattern=r" ?"
)
self.config = config
logging.getLogger("transformers").setLevel(logging.CRITICAL)

# If an inference server isn't already running then start a new server in a
# background process and store the process ID
self.server_process: subprocess.Popen | None
if config.generator.server is None:
# We can only run the inference server if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError(
"The `vLLMGenerator` requires a CUDA-compatible GPU to run. "
"Please ensure that a compatible GPU is available and try again."
)

config.generator.server = "0.0.0.0"
self.tokenizer = AutoTokenizer.from_pretrained(config.generator.model)
self.server_process = self.start_inference_server()
else:
self.server_process = None

def generate(
self, query: str, documents: list[Document]
) -> GeneratedAnswer | typing.Generator[GeneratedAnswer, None, None]:
"""Generate an answer from a query and relevant documents.
super().__init__(config=config)

Args:
query:
The query to answer.
documents:
The relevant documents.
def start_inference_server(self) -> subprocess.Popen:
"""Start the vLLM inference server.
Returns:
The generated answer.
The inference server process.
"""
logger.info(
f"Generating answer for the query {query!r} and {len(documents):,} "
"documents..."
)

system_prompt = self.config.generator.system_prompt
user_prompt = self.config.generator.prompt.format(
documents=json.dumps([document.model_dump() for document in documents]),
query=query,
)

chat_template_kwargs = dict(
chat_template=self.tokenizer.chat_template,
add_generation_prompt=True,
tokenize=False,
)
try:
prompt = self.tokenizer.apply_chat_template(
conversation=[
dict(role="system", content=system_prompt),
dict(role="user", content=user_prompt),
],
**chat_template_kwargs,
)
except TemplateError:
prompt = self.tokenizer.apply_chat_template(
conversation=[
dict(role="user", content=system_prompt + "\n\n" + user_prompt)
],
**chat_template_kwargs,
)

sampling_params = SamplingParams(
max_tokens=self.config.generator.max_tokens,
temperature=self.config.generator.temperature,
stop=["</answer>"],
logits_processors=[self.logits_processor],
)

model_output = self.model.generate(
prompts=[prompt], sampling_params=sampling_params
logger.info("Starting vLLM server...")

# Start server using the vLLM entrypoint
process = subprocess.Popen(
args=[
"python",
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
self.config.generator.model,
"--max-model-len",
str(self.config.generator.max_model_len),
"--gpu-memory-utilization",
str(self.config.generator.gpu_memory_utilization),
"--chat-template",
self.tokenizer.chat_template,
"--host",
self.config.generator.server,
"--port",
str(self.config.generator.port),
],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
)
generated_output = model_output[0].outputs[0].text

try:
generated_dict = json.loads(generated_output)
except json.JSONDecodeError:
raise ValueError(
f"Could not decode JSON from model output: {generated_output}"
)

try:
generated_obj = GeneratedAnswer.model_validate(generated_dict)
except ValidationError:
raise ValueError(f"Could not validate model output: {generated_dict}")
# Wait for the server to start
stderr = process.stderr
assert stderr is not None
for seconds in range(self.config.generator.timeout):
update = stderr.readline().decode("utf-8")
if "Uvicorn running" in update:
logger.info(f"vLLM server started after {seconds} seconds.")
break
sleep(1)
else:
raise RuntimeError("vLLM server failed to start.")

logger.info(f"Generated answer: {generated_obj.answer!r}")
return generated_obj
return process

def __del__(self) -> None:
"""Clear the GPU memory used by the model, and remove the model itself."""
if hasattr(self, "model"):
del self.model
"""Close down the vLLM server, if we started a new one."""
if self.server_process is not None:
self.server_process.kill()
del self
try:
destroy_model_parallel()
except ImportError:
pass
clear_memory()
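
The heart of the change: `VllmGenerator` no longer drives vLLM through its Python API but boots `vllm.entrypoints.openai.api_server` in a subprocess and inherits the OpenAI client from `OpenaiGenerator`, passing `GeneratedAnswer.model_json_schema()` via `extra_body["guided_json"]` so the server constrains decoding to schema-valid JSON. Since guided decoding is a vLLM extension, the diff only sets it when `config.generator.name == "vllm"`. A stand-alone sketch of that client-side pattern against an already-running vLLM server (the URL, the prompt and the `Answer` model are illustrative stand-ins, not taken from the repo):

```python
# Guided-JSON sketch, assuming a vLLM OpenAI-compatible server is listening on
# localhost:8000. `Answer` is a stand-in for the repo's GeneratedAnswer model;
# vLLM reads `guided_json` from `extra_body` and constrains generation to the schema.
from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    answer: str
    sources: list[str]


client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
completion = client.chat.completions.create(
    model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit",
    messages=[
        {"role": "user", "content": "Reply as JSON with `answer` and `sources`."}
    ],
    temperature=0.0,
    max_tokens=256,
    extra_body={"guided_json": Answer.model_json_schema()},
)
print(completion.choices[0].message.content)
```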
38 changes: 35 additions & 3 deletions src/ragger/utils.py
@@ -2,7 +2,9 @@

import gc
import importlib
import os
import re
import sys
from typing import Type

import torch
@@ -103,15 +105,27 @@ def get_component_by_name(class_name: str, component_type: str) -> Type:
Returns:
The class.
Raises:
ValueError:
If the module or class cannot be found.
"""
# Get the snake_case and PascalCase version of the class name
full_class_name = f"{class_name}_{component_type}"
name_pascal = snake_to_pascal(snake_string=full_class_name)

# Get the class from the module
# Get the module
module_name = f"ragger.{component_type}"
module = importlib.import_module(name=module_name)
class_: Type = getattr(module, name_pascal)
try:
module = importlib.import_module(name=module_name)
except ModuleNotFoundError:
raise ValueError(f"Module {module_name!r}' not found.")

# Get the class from the module
try:
class_: Type = getattr(module, name_pascal)
except AttributeError:
raise ValueError(f"Class {name_pascal!r} not found in module {module_name!r}.")

return class_

@@ -138,3 +152,21 @@ def load_ragger_components(config: DictConfig) -> Components:
class_name=config.generator.name, component_type="generator"
),
)


class HiddenPrints:
"""Context manager which removes all terminal output."""

def __enter__(self):
"""Enter the context manager."""
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")

def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit the context manager."""
sys.stdout.close()
sys.stderr.close()
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr
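
`HiddenPrints` backs the "Block stderr when loading tokenizer" and "refactor: Use HiddenPrints" commits: anything written to stdout or stderr inside the `with` block is discarded. A usage sketch; the tokenizer load is an assumed call site rather than code from this diff:

```python
# Usage sketch for HiddenPrints. Output inside the `with` block goes to os.devnull;
# the AutoTokenizer call is only an assumed example of a noisy operation.
from transformers import AutoTokenizer

from ragger.utils import HiddenPrints

with HiddenPrints():
    tokenizer = AutoTokenizer.from_pretrained(
        "ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit"
    )

print(type(tokenizer).__name__)  # prints normally again outside the block
```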
7 changes: 5 additions & 2 deletions tests/conftest.py
@@ -114,13 +114,16 @@ def vllm_generator_params(system_prompt, prompt) -> typing.Generator[dict, None,
yield dict(
name="vllm",
model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit",
max_model_len=10_000,
gpu_memory_utilization=0.95,
temperature=0.0,
max_tokens=128,
stream=False,
timeout=60,
system_prompt=system_prompt,
prompt=prompt,
max_model_len=10_000,
gpu_memory_utilization=0.95,
server=None,
port=9999,
)
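
For context, a hypothetical companion fixture (not part of this diff) showing how these params might be wrapped into the config object `VllmGenerator` expects; the nesting under a `generator` key is an assumption about how the tests assemble the Hydra config:

```python
# Hypothetical sketch: build one VllmGenerator from the params above. Giving the
# fixture a broader scope (module or session) is one way the "Initialise the
# VllmGenerator fewer times in tests" commit could be realised.
import pytest
from omegaconf import DictConfig

from ragger.generator import VllmGenerator


@pytest.fixture()
def vllm_generator(vllm_generator_params: dict) -> VllmGenerator:
    config = DictConfig({"generator": vllm_generator_params})
    return VllmGenerator(config=config)
```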

