Add defaultrole to LLM pipeline, closes #841
davidmezzetti committed Dec 19, 2024
1 parent 9bd2da7 · commit 952a757
Showing 7 changed files with 29 additions and 4 deletions.
docs/pipeline/text/llm.md (4 additions, 0 deletions)
@@ -47,10 +47,14 @@ llm([
{"role": "user", "content": "Answer the following question..."}
])

+ # Set the default role to user, which converts string inputs to chat messages
+ llm("Answer the following question...", defaultrole="user")
```

The LLM pipeline automatically detects the underlying LLM framework. This can also be manually set.

[Hugging Face Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/abetlen/llama-cpp-python) and [hosted API models via LiteLLM](https://github.com/BerriAI/litellm) are all supported by this pipeline.

See the [LiteLLM documentation](https://litellm.vercel.app/docs/providers) for the options available with LiteLLM models. llama.cpp models support both local and remote GGUF paths on the HF Hub.
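As a hedged sketch of manual framework selection (the model path and `method` value here are assumptions inferred from the backends listed above, not the collapsed docs content):

```python
from txtai import LLM

# Autodetected from the model path by default
llm = LLM("microsoft/Phi-3-mini-4k-instruct")

# Assumed value for explicit selection, matching the frameworks named above;
# check the txtai docs for the exact method strings
llm = LLM("microsoft/Phi-3-mini-4k-instruct", method="transformers")
```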

src/python/txtai/pipeline/llm/generation.py (7 additions, 2 deletions)
@@ -24,7 +24,7 @@ def __init__(self, path=None, template=None, **kwargs):
self.template = template
self.kwargs = kwargs

- def __call__(self, text, maxlength, stream, stop, **kwargs):
+ def __call__(self, text, maxlength, stream, stop, defaultrole, **kwargs):
"""
Generates text. Supports the following input formats:
@@ -36,6 +36,7 @@ def __call__(self, text, maxlength, stream, stop, **kwargs):
maxlength: maximum sequence length
stream: stream response if True, defaults to False
stop: list of stop strings
+ defaultrole: default role to apply to text inputs; "prompt" passes inputs through as raw prompts (default), "user" wraps them as user chat messages
kwargs: additional generation keyword arguments
Returns:
@@ -48,7 +48,11 @@ def __call__(self, text, maxlength, stream, stop, **kwargs):
# Apply template, if necessary
if self.template:
formatter = TemplateFormatter()
- texts = [formatter.format(self.template, text=x) for x in texts]
+ texts = [formatter.format(self.template, text=x) if isinstance(x, str) else x for x in texts]

+ # Apply default role, if necessary
+ if defaultrole == "user":
+     texts = [[{"role": "user", "content": x}] if isinstance(x, str) else x for x in texts]

# Run pipeline
results = self.execute(texts, maxlength, stream, stop, **kwargs)
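A minimal standalone sketch of the conversion in this hunk, with hypothetical inputs, showing how `defaultrole="user"` wraps strings while leaving chat messages untouched:

```python
# Hypothetical inputs for illustration
texts = ["Answer the following question...", [{"role": "user", "content": "2+2?"}]]
defaultrole = "user"

# Strings become single-message chats; existing chat inputs pass through unchanged
if defaultrole == "user":
    texts = [[{"role": "user", "content": x}] if isinstance(x, str) else x for x in texts]

print(texts[0])  # [{'role': 'user', 'content': 'Answer the following question...'}]
```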
src/python/txtai/pipeline/llm/llm.py (3 additions, 2 deletions)
@@ -38,7 +38,7 @@ def __init__(self, path=None, method=None, **kwargs):
# Generation instance
self.generator = GenerationFactory.create(path, method, **kwargs)

- def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs):
+ def __call__(self, text, maxlength=512, stream=False, stop=None, defaultrole="prompt", **kwargs):
"""
Generates text. Supports the following input formats:
@@ -50,6 +50,7 @@ def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs):
maxlength: maximum sequence length
stream: stream response if True, defaults to False
stop: list of stop strings, defaults to None
+ defaultrole: default role to apply to text inputs; "prompt" passes inputs through as raw prompts (default), "user" wraps them as user chat messages
kwargs: additional generation keyword arguments
Returns:
Expand All @@ -60,4 +61,4 @@ def __call__(self, text, maxlength=512, stream=False, stop=None, **kwargs):
logger.debug(text)

# Run LLM generation
- return self.generator(text, maxlength, stream, stop, **kwargs)
+ return self.generator(text, maxlength, stream, stop, defaultrole, **kwargs)
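A short usage sketch of the new parameter (the tiny model path is borrowed from the unit tests below; any supported model works):

```python
from txtai import LLM

llm = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Default behavior: the string is passed through as a raw prompt
llm("Answer the following question...")

# With defaultrole="user", the call above becomes equivalent to
# llm([{"role": "user", "content": "Answer the following question..."}])
llm("Answer the following question...", defaultrole="user")
```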
src/python/txtai/pipeline/llm/rag.py (1 addition, 0 deletions)
@@ -310,6 +310,7 @@ def answers(self, questions, contexts, **kwargs):
Args:
questions: questions
contexts: question context
+ kwargs: additional keyword arguments to pass to model
Returns:
answers
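A hedged sketch of the documented pass-through; the setup below is an assumption for illustration, the point being that generation kwargs given to a RAG call flow through `answers()` to the underlying LLM:

```python
from txtai import Embeddings, RAG

# Assumed setup, not part of this commit
embeddings = Embeddings(content=True)
embeddings.index(["Paris is the capital of France"])

rag = RAG(embeddings, "hf-internal-testing/tiny-random-LlamaForCausalLM")

# Generation kwargs such as maxlength (and now defaultrole) are forwarded to the model
rag("What is the capital of France?", maxlength=64)
```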
test/python/testpipeline/testllm/testlitellm.py (3 additions, 0 deletions)
@@ -78,5 +78,8 @@ def testGeneration(self):
model = LLM("huggingface/t5-small", api_base="http://127.0.0.1:8000")
self.assertEqual(model("The sky is"), "blue")

+ # Test default role
+ self.assertEqual(model("The sky is", defaultrole="user"), "blue")

# Test streaming
self.assertEqual(" ".join(x for x in model("The sky is", stream=True)), "blue")
test/python/testpipeline/testllm/testllama.py (3 additions, 0 deletions)
@@ -69,5 +69,8 @@ def testGeneration(self):
messages = [{"role": "system", "content": "You are a helpful assistant. You answer math problems."}, {"role": "user", "content": "2+2?"}]
self.assertIsNotNone(model(messages, maxlength=10, seed=0, stop=["."]))

+ # Test default role
+ self.assertIsNotNone(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="user"))

# Test streaming
self.assertEqual(" ".join(x for x in model("2 + 2 = ", maxlength=10, stream=True, seed=0, stop=["."]))[0], "4")
test/python/testpipeline/testllm/testllm.py (8 additions, 0 deletions)
@@ -54,6 +54,14 @@ def testCustomNotFound(self):
with self.assertRaises(ImportError):
LLM("hf-internal-testing/tiny-random-gpt2", method="notfound.generation")

+ def testDefaultRole(self):
+     """
+     Test default role
+     """
+
+     model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM")
+     self.assertIsNotNone(model("Hello, how are", defaultrole="user"))

def testExternal(self):
"""
Test externally loaded model
Expand Down
