diff --git a/CHANGELOG.md b/CHANGELOG.md index 39467ff..10c4b4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ *Andrea Sponziello* ### **Copyrigth**: *Tiledesk SRL* +## [2024-07-29] +### 0.2.9 +- add: n_messages on /api/ask to set the maximum number of messages to include + ## [2024-07-27] ### 0.2.8 - add: history on /api/ask diff --git a/pyproject.toml b/pyproject.toml index 264f347..15260a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tilellm" -version = "0.2.8" +version = "0.2.9" description = "tiledesk for RAG" authors = ["Gianluca Lorenzo "] repository = "https://github.com/Tiledesk/tiledesk-llm" diff --git a/tilellm/controller/controller.py b/tilellm/controller/controller.py index 963c186..8f3a1ce 100644 --- a/tilellm/controller/controller.py +++ b/tilellm/controller/controller.py @@ -199,7 +199,7 @@ async def ask_to_llm(question, chat_model=None): qa_prompt = ChatPromptTemplate.from_messages( [ ("system", question.system_context), - MessagesPlaceholder("chat_history_a"), + MessagesPlaceholder("chat_history_a", n_messages=question.n_messages), ("human", "{input}"), ] ) diff --git a/tilellm/models/item_model.py b/tilellm/models/item_model.py index 1984306..8deb78a 100644 --- a/tilellm/models/item_model.py +++ b/tilellm/models/item_model.py @@ -124,6 +124,7 @@ class QuestionToLLM(BaseModel): debug: bool = Field(default_factory=lambda: False) system_context: str = Field(default="You are a helpful AI bot. Always reply in the same language of the question.") chat_history_dict: Optional[Dict[str, ChatEntry]] = None + n_messages: int = Field(default_factory=lambda: None) @field_validator("temperature") def temperature_range(cls, v): @@ -132,6 +133,13 @@ def temperature_range(cls, v): raise ValueError("Temperature must be between 0.0 and 1.0.") return v + @field_validator("n_messages") + def n_messages_range(cls, v): + """Ensures n_messages is within greater than 0""" + if not v > 0: + raise ValueError("n_messages must be greater than 0") + return v + @field_validator("max_tokens") def max_tokens_range(cls, v): """Ensures max_tokens is a positive integer."""