Feat/use vllm server #40

Merged: 58 commits, May 22, 2024

Commits
2d0f5f7
feat: Use vllm server
saattrupdan May 21, 2024
223cd96
fix: Use self.model
saattrupdan May 21, 2024
50c907f
debug
saattrupdan May 21, 2024
d997f28
debug
saattrupdan May 21, 2024
24fd935
feat: Try using guided_json
saattrupdan May 21, 2024
984079f
fix: Use extra_body
saattrupdan May 21, 2024
309baac
fix: Use json
saattrupdan May 21, 2024
35fab6c
fix: Use model_json_schema
saattrupdan May 21, 2024
9b63001
chore: Include response_format
saattrupdan May 21, 2024
0061921
chore: Logging
saattrupdan May 21, 2024
86745cb
chore: Remove logging
saattrupdan May 21, 2024
25c8e07
debug
saattrupdan May 22, 2024
7e0528c
fix: Do not break streaming if chunk_str is None
saattrupdan May 22, 2024
4fef832
debug
saattrupdan May 22, 2024
05a96d3
feat: Spawn new vLLM server if not already running
saattrupdan May 22, 2024
ff9a16e
fix: Do not use api_key if running vLLM generator
saattrupdan May 22, 2024
d5b8f6b
fix: vLLM config
saattrupdan May 22, 2024
c1dc715
chore: Remove breakpoint
saattrupdan May 22, 2024
a73d98d
debug
saattrupdan May 22, 2024
b9e72f0
debug
saattrupdan May 22, 2024
6622b7a
fix: Set server after booting it
saattrupdan May 22, 2024
acde691
debug
saattrupdan May 22, 2024
ef4c33e
debug
saattrupdan May 22, 2024
00b38bd
fix: Add sleep after server start
saattrupdan May 22, 2024
d532c77
fix: Only require CUDA to start the vLLM inference server, not to use…
saattrupdan May 22, 2024
dc0be2c
fix: Only set `guided_json` if using vLLM
saattrupdan May 22, 2024
3fd39c6
tests: vLLM tests
saattrupdan May 22, 2024
ad04ed1
feat: Add more args to vLLM server
saattrupdan May 22, 2024
226a88a
fix: Typo
saattrupdan May 22, 2024
e726226
debug
saattrupdan May 22, 2024
dedb032
fix: Up vLLM startup sleep time
saattrupdan May 22, 2024
598f286
debug
saattrupdan May 22, 2024
36e6d0b
debug
saattrupdan May 22, 2024
192721a
debug
saattrupdan May 22, 2024
89a56c2
debug
saattrupdan May 22, 2024
a2936b6
fix: Add port back in
saattrupdan May 22, 2024
e414942
fix: Set up self.server in OpenaiGenerator correctly
saattrupdan May 22, 2024
6d01292
debug
saattrupdan May 22, 2024
905dd97
fix: Store config in VllmGenerator
saattrupdan May 22, 2024
63cf5c9
debug
saattrupdan May 22, 2024
837ce28
feat: Check manually if Uvicorn server has started
saattrupdan May 22, 2024
14e5d33
feat: Block stderr when loading tokenizer
saattrupdan May 22, 2024
7c1298c
debug
saattrupdan May 22, 2024
bc1641b
refactor: Use HiddenPrints
saattrupdan May 22, 2024
67de367
fix: Block transformers logging
saattrupdan May 22, 2024
aaae8cb
feat: Add --host back in
saattrupdan May 22, 2024
8b67836
debug
saattrupdan May 22, 2024
9a93b7b
fix: Add `del self` in `__del__`
saattrupdan May 22, 2024
3e15ea2
chore: Ignore ResourceWarning in pytest
saattrupdan May 22, 2024
38cb047
tests: Initialise the VllmGenerator fewer times in tests
saattrupdan May 22, 2024
2c3ff56
fix: Do not hardcode different ports
saattrupdan May 22, 2024
9b2fddc
tests: Use same VllmGenerator
saattrupdan May 22, 2024
21abef2
tests: Remove validity check test, as it is impossible with VllmGener…
saattrupdan May 22, 2024
65e538a
tests: Remove random_seed from VllmGenerator config
saattrupdan May 22, 2024
6facb53
docs: Add comments
saattrupdan May 22, 2024
aa64ac4
fix: Raise ValueError in get_component_by_name if module or class don…
saattrupdan May 22, 2024
213950b
docs: Update coverage badge
saattrupdan May 22, 2024
46c61ef
chore: Re-instate pre-commit hook
saattrupdan May 22, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@
A repository for general-purpose RAG applications.

______________________________________________________________________
[![Code Coverage](https://img.shields.io/badge/Coverage-68%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)
[![Code Coverage](https://img.shields.io/badge/Coverage-70%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)


Developer(s):
7 changes: 5 additions & 2 deletions config/generator/vllm.yaml
@@ -1,9 +1,12 @@
name: vllm
model: ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit
max_model_len: 10_000
gpu_memory_utilization: 0.95
temperature: 0.0
max_tokens: 256
stream: true
timeout: 60
system_prompt: ${..language.system_prompt}
prompt: ${..language.prompt}
max_model_len: 10_000
gpu_memory_utilization: 0.95
server: null
port: 8000
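
For context on the new `server` and `port` fields: leaving `server: null` makes `VllmGenerator` spawn its own vLLM server, while pointing it at an existing host skips the spawn. Below is a minimal sketch (not part of the diff; the host value is purely illustrative) of how the two fields resolve into the OpenAI-compatible base URL, mirroring the `OpenaiGenerator.__init__` logic further down.

```python
from omegaconf import OmegaConf

# Trimmed-down generator config; "0.0.0.0" stands in for any reachable host.
cfg = OmegaConf.create(
    {"generator": {"name": "vllm", "server": "0.0.0.0", "port": 8000}}
)

host = cfg.generator.server
if not host.startswith("http"):
    host = f"http://{host}"
base_url = f"{host}:{cfg.generator.port}/v1"
print(base_url)  # http://0.0.0.0:8000/v1
```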
1 change: 1 addition & 0 deletions pyproject.toml
@@ -95,6 +95,7 @@ filterwarnings = [
"ignore::DeprecationWarning",
"ignore::PendingDeprecationWarning",
"ignore::ImportWarning",
"ignore::ResourceWarning",
]
log_cli_level = "info"
testpaths = [
207 changes: 90 additions & 117 deletions src/ragger/generator.py
@@ -1,14 +1,14 @@
"""Generation of an answer from a query and a list of relevant documents."""

import importlib.util
import json
import logging
import os
import subprocess
import typing
from time import sleep

import torch
from dotenv import load_dotenv
from jinja2 import TemplateError
from omegaconf import DictConfig
from openai import OpenAI, Stream
from openai.types.chat import (
@@ -18,23 +18,9 @@
from openai.types.chat.completion_create_params import ResponseFormat
from pydantic import ValidationError
from pydantic_core import from_json
from transformers import AutoTokenizer

from .data_models import Document, GeneratedAnswer, Generator
from .utils import clear_memory

if importlib.util.find_spec("vllm") is not None:
from vllm import LLM, SamplingParams
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor,
)

try:
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel,
)
except ImportError:
from vllm.distributed.parallel_state import destroy_model_parallel


load_dotenv()

@@ -54,9 +40,24 @@ def __init__(self, config: DictConfig) -> None:
"""
super().__init__(config=config)
logging.getLogger("httpx").setLevel(logging.CRITICAL)
api_key = os.environ[self.config.generator.api_key_variable_name]

if hasattr(config.generator, "api_key_variable_name"):
env_var_name = config.generator.api_key_variable_name
api_key = os.environ[env_var_name].strip('"')
else:
api_key = None

self.server: str | None
if hasattr(config.generator, "server"):
host = config.generator.server
if not host.startswith("http"):
host = f"http://{host}"
self.server = f"{host}:{config.generator.port}/v1"
else:
self.server = None

self.client = OpenAI(
api_key=api_key.strip('"'), timeout=self.config.generator.timeout
base_url=self.server, api_key=api_key, timeout=self.config.generator.timeout
)

def generate(
@@ -91,6 +92,11 @@ def generate(
),
),
]

extra_body = dict()
if self.config.generator.name == "vllm":
extra_body["guided_json"] = GeneratedAnswer.model_json_schema()

model_output = self.client.chat.completions.create(
messages=messages,
model=self.config.generator.model,
@@ -99,7 +105,9 @@
stream=self.config.generator.stream,
stop=["</answer>"],
response_format=ResponseFormat(type="json_object"),
extra_body=extra_body,
)

if isinstance(model_output, Stream):

def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
@@ -108,7 +116,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
for chunk in model_output:
chunk_str = chunk.choices[0].delta.content
if chunk_str is None:
break
continue
generated_output += chunk_str
try:
generated_dict = from_json(
@@ -174,7 +182,7 @@ def streamer() -> typing.Generator[GeneratedAnswer, None, None]:
return generated_obj


class VllmGenerator(Generator):
class VllmGenerator(OpenaiGenerator):
"""A generator that uses a vLLM model to generate answers."""

def __init__(self, config: DictConfig) -> None:
@@ -184,110 +192,75 @@ def __init__(self, config: DictConfig) -> None:
config:
The Hydra configuration.
"""
super().__init__(config=config)

if not torch.cuda.is_available():
raise RuntimeError(
"The `vLLMGenerator` requires a CUDA-compatible GPU to run. "
"Please ensure that a compatible GPU is available and try again."
)

# We need to remove the model from GPU memory before creating a new one
destroy_model_parallel()
clear_memory()

self.model = LLM(
model=config.generator.model,
gpu_memory_utilization=config.generator.gpu_memory_utilization,
max_model_len=config.generator.max_model_len,
seed=config.random_seed,
tensor_parallel_size=torch.cuda.device_count(),
)
self.tokenizer = self.model.get_tokenizer()
self.logits_processor = JSONLogitsProcessor(
schema=GeneratedAnswer, tokenizer=self.tokenizer, whitespace_pattern=r" ?"
)
self.config = config
logging.getLogger("transformers").setLevel(logging.CRITICAL)

# If an inference server isn't already running then start a new server in a
# background process and store the process ID
self.server_process: subprocess.Popen | None
if config.generator.server is None:
# We can only run the inference server if CUDA is available
if not torch.cuda.is_available():
raise RuntimeError(
"The `vLLMGenerator` requires a CUDA-compatible GPU to run. "
"Please ensure that a compatible GPU is available and try again."
)

config.generator.server = "0.0.0.0"
self.tokenizer = AutoTokenizer.from_pretrained(config.generator.model)
self.server_process = self.start_inference_server()
else:
self.server_process = None

def generate(
self, query: str, documents: list[Document]
) -> GeneratedAnswer | typing.Generator[GeneratedAnswer, None, None]:
"""Generate an answer from a query and relevant documents.
super().__init__(config=config)

Args:
query:
The query to answer.
documents:
The relevant documents.
def start_inference_server(self) -> subprocess.Popen:
"""Start the vLLM inference server.
Returns:
The generated answer.
The inference server process.
"""
logger.info(
f"Generating answer for the query {query!r} and {len(documents):,} "
"documents..."
)

system_prompt = self.config.generator.system_prompt
user_prompt = self.config.generator.prompt.format(
documents=json.dumps([document.model_dump() for document in documents]),
query=query,
)

chat_template_kwargs = dict(
chat_template=self.tokenizer.chat_template,
add_generation_prompt=True,
tokenize=False,
)
try:
prompt = self.tokenizer.apply_chat_template(
conversation=[
dict(role="system", content=system_prompt),
dict(role="user", content=user_prompt),
],
**chat_template_kwargs,
)
except TemplateError:
prompt = self.tokenizer.apply_chat_template(
conversation=[
dict(role="user", content=system_prompt + "\n\n" + user_prompt)
],
**chat_template_kwargs,
)

sampling_params = SamplingParams(
max_tokens=self.config.generator.max_tokens,
temperature=self.config.generator.temperature,
stop=["</answer>"],
logits_processors=[self.logits_processor],
)

model_output = self.model.generate(
prompts=[prompt], sampling_params=sampling_params
logger.info("Starting vLLM server...")

# Start server using the vLLM entrypoint
process = subprocess.Popen(
args=[
"python",
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
self.config.generator.model,
"--max-model-len",
str(self.config.generator.max_model_len),
"--gpu-memory-utilization",
str(self.config.generator.gpu_memory_utilization),
"--chat-template",
self.tokenizer.chat_template,
"--host",
self.config.generator.server,
"--port",
str(self.config.generator.port),
],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
)
generated_output = model_output[0].outputs[0].text

try:
generated_dict = json.loads(generated_output)
except json.JSONDecodeError:
raise ValueError(
f"Could not decode JSON from model output: {generated_output}"
)

try:
generated_obj = GeneratedAnswer.model_validate(generated_dict)
except ValidationError:
raise ValueError(f"Could not validate model output: {generated_dict}")
# Wait for the server to start
stderr = process.stderr
assert stderr is not None
for seconds in range(self.config.generator.timeout):
update = stderr.readline().decode("utf-8")
if "Uvicorn running" in update:
logger.info(f"vLLM server started after {seconds} seconds.")
break
sleep(1)
else:
raise RuntimeError("vLLM server failed to start.")

logger.info(f"Generated answer: {generated_obj.answer!r}")
return generated_obj
return process

def __del__(self) -> None:
"""Clear the GPU memory used by the model, and remove the model itself."""
if hasattr(self, "model"):
del self.model
"""Close down the vLLM server, if we started a new one."""
if self.server_process is not None:
self.server_process.kill()
del self
try:
destroy_model_parallel()
except ImportError:
pass
clear_memory()
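
The net effect of the generator changes is that `VllmGenerator` now talks to a vLLM OpenAI-compatible server instead of an in-process `LLM` object, with structured output enforced via `guided_json`. The sketch below shows the shape of the request the generator issues; it assumes a server is already listening on http://0.0.0.0:8000 and that the `ragger` package is installed, and the prompt strings are placeholders.

```python
from openai import OpenAI

from ragger.data_models import GeneratedAnswer

# The dummy api_key follows the usual convention for local vLLM servers.
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit",
    messages=[
        dict(role="system", content="You answer questions about the given documents."),
        dict(role="user", content="Documents: [...]\n\nQuestion: What is ragger?"),
    ],
    temperature=0.0,
    max_tokens=256,
    # vLLM-specific extension: constrain the output to the GeneratedAnswer schema.
    extra_body=dict(guided_json=GeneratedAnswer.model_json_schema()),
)

answer = GeneratedAnswer.model_validate_json(completion.choices[0].message.content)
```

The streaming path in `generate` follows the same shape, iterating over chunks and skipping `None` deltas (the `continue` fix above) while re-parsing the accumulated JSON.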
38 changes: 35 additions & 3 deletions src/ragger/utils.py
@@ -2,7 +2,9 @@

import gc
import importlib
import os
import re
import sys
from typing import Type

import torch
@@ -103,15 +105,27 @@ def get_component_by_name(class_name: str, component_type: str) -> Type:
Returns:
The class.
Raises:
ValueError:
If the module or class cannot be found.
"""
# Get the snake_case and PascalCase version of the class name
full_class_name = f"{class_name}_{component_type}"
name_pascal = snake_to_pascal(snake_string=full_class_name)

# Get the class from the module
# Get the module
module_name = f"ragger.{component_type}"
module = importlib.import_module(name=module_name)
class_: Type = getattr(module, name_pascal)
try:
module = importlib.import_module(name=module_name)
except ModuleNotFoundError:
raise ValueError(f"Module {module_name!r}' not found.")

# Get the class from the module
try:
class_: Type = getattr(module, name_pascal)
except AttributeError:
raise ValueError(f"Class {name_pascal!r} not found in module {module_name!r}.")

return class_

@@ -138,3 +152,21 @@ def load_ragger_components(config: DictConfig) -> Components:
class_name=config.generator.name, component_type="generator"
),
)


class HiddenPrints:
"""Context manager which removes all terminal output."""

def __enter__(self):
"""Enter the context manager."""
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")

def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit the context manager."""
sys.stdout.close()
sys.stderr.close()
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr
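
A short usage sketch for the two utility changes (assuming `ragger` is importable): component lookups now fail with a `ValueError`, and `HiddenPrints` silences stdout and stderr for the duration of a block.

```python
from ragger.utils import HiddenPrints, get_component_by_name

# Unknown components now raise ValueError rather than an opaque
# ModuleNotFoundError or AttributeError.
try:
    get_component_by_name(class_name="does_not_exist", component_type="generator")
except ValueError as error:
    print(error)  # Class 'DoesNotExistGenerator' not found in module 'ragger.generator'.

# Suppress terminal noise, e.g. while a tokenizer is being loaded.
with HiddenPrints():
    print("this never reaches the terminal")
```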
7 changes: 5 additions & 2 deletions tests/conftest.py
@@ -114,13 +114,16 @@ def vllm_generator_params(system_prompt, prompt) -> typing.Generator[dict, None,
yield dict(
name="vllm",
model="ThatsGroes/munin-SkoleGPTOpenOrca-7b-16bit",
max_model_len=10_000,
gpu_memory_utilization=0.95,
temperature=0.0,
max_tokens=128,
stream=False,
timeout=60,
system_prompt=system_prompt,
prompt=prompt,
max_model_len=10_000,
gpu_memory_utilization=0.95,
server=None,
port=9999,
)

