diff --git a/chat_server/blueprints/personas.py b/chat_server/blueprints/personas.py index 6c0c8abd..5ca74785 100644 --- a/chat_server/blueprints/personas.py +++ b/chat_server/blueprints/personas.py @@ -25,6 +25,7 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from fastapi import APIRouter from starlette.responses import JSONResponse @@ -47,7 +48,7 @@ PersonaData, ) from chat_server.server_utils.api_dependencies.validators import permitted_access - +from chat_server.server_utils.persona_utils import notify_personas_changed, list_personas as _list_personas from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI @@ -60,34 +61,10 @@ @router.get("/list") async def list_personas( current_user: CurrentUserData, - request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel), -): + request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel) +) -> JSONResponse: """Lists personas matching query params""" - filters = [] - if request_model.llms: - filters.append( - MongoFilter( - key="supported_llms", - value=request_model.llms, - logical_operator=MongoLogicalOperators.ALL, - ) - ) - if request_model.user_id and request_model.user_id != "*": - filters.append(MongoFilter(key="user_id", value=request_model.user_id)) - else: - user_filter = [{"user_id": None}, {"user_id": current_user.user_id}] - filters.append( - MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR) - ) - if request_model.only_enabled: - filters.append(MongoFilter(key="enabled", value=True)) - items = MongoDocumentsAPI.PERSONAS.list_items( - filters=filters, result_as_cursor=False - ) - for item in items: - item["id"] = item.pop("_id") - item["enabled"] = item.get("enabled", False) - return 
JSONResponse(content={"items": items}) + return await _list_personas(current_user, request_model) @router.get("/get/{persona_id}") @@ -112,6 +89,8 @@ async def add_persona( if existing_model: raise DuplicatedItemException MongoDocumentsAPI.PERSONAS.add_item(data=request_model.model_dump()) + # TODO: This will not scale to multiple server instances + await notify_personas_changed(request_model.supported_llms) return KlatAPIResponse.OK @@ -131,6 +110,8 @@ async def set_persona( MongoDocumentsAPI.PERSONAS.update_item( filters=mongo_filter, data=request_model.model_dump() ) + # TODO: This will not scale to multiple server instances + await notify_personas_changed(request_model.supported_llms) return KlatAPIResponse.OK @@ -140,6 +121,8 @@ async def delete_persona( ): """Deletes persona""" MongoDocumentsAPI.PERSONAS.delete_item(item_id=request_model.persona_id) + # TODO: This will not scale to multiple server instances + await notify_personas_changed() return KlatAPIResponse.OK @@ -157,4 +140,6 @@ async def toggle_persona_state( ) if updated_data.matched_count == 0: raise ItemNotFoundException + # TODO: This will not scale to multiple server instances + await notify_personas_changed() return KlatAPIResponse.OK diff --git a/chat_server/server_utils/persona_utils.py b/chat_server/server_utils/persona_utils.py new file mode 100644 index 00000000..639130c7 --- /dev/null +++ b/chat_server/server_utils/persona_utils.py @@ -0,0 +1,100 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. 
# Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from time import time

from typing import Optional, List
from asyncio import Lock

from starlette.responses import JSONResponse

from chat_server.server_utils.api_dependencies import ListPersonasQueryModel, CurrentUserData
from chat_server.sio.server import sio
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI

# Held for the whole read-and-emit sequence in `notify_personas_changed` so
# that overlapping notifications are processed one at a time.
_LOCK = Lock()


async def notify_personas_changed(supported_llms: Optional[List[str]] = None):
    """
    Emit an SIO event for each LLM affected by a persona change. This sends a
    complete set of personas rather than only the changed one to prevent sync
    conflicts and simplify client-side logic.
    :param supported_llms: List of LLM names affected by a transaction. If None
        (or empty), an update is built for every LLM that is referenced by at
        least one enabled persona; an LLM with no enabled personas left is
        therefore not included in the emitted payload.
    """
    async with _LOCK:
        # NOTE(review): `current_user=None` means no per-user visibility filter
        # is applied below (unless the query model carries a `user_id`), so the
        # broadcast may include personas owned by any user - confirm intended.
        resp = await list_personas(None,
                                   ListPersonasQueryModel(only_enabled=True))
        update_time = round(time())
        enabled_personas = json.loads(resp.body.decode())
        valid_personas = {}
        if supported_llms:
            # Only broadcast updates for LLMs affected by an insert/change request
            for llm in supported_llms:
                valid_personas[llm] = [
                    per for per in enabled_personas["items"]
                    # Tolerate documents stored without `supported_llms`
                    if llm in (per.get("supported_llms") or ())
                ]
        else:
            # Delete request does not have LLM context, update everything
            for persona in enabled_personas["items"]:
                for llm in persona.get("supported_llms") or ():
                    valid_personas.setdefault(llm, []).append(persona)
        # `sio` is presumably a socketio.AsyncServer (its definition is not
        # visible here), whose `emit` is a coroutine: it must be awaited or the
        # event is never actually sent.
        await sio.emit("configured_personas_changed",
                       {"personas": valid_personas,
                        "update_time": update_time})


async def list_personas(current_user: Optional[CurrentUserData],
                        request_model: ListPersonasQueryModel) -> JSONResponse:
    """
    List personas matching the given query parameters.
    :param current_user: Requesting user. May be None for internal calls, in
        which case no "own or public" user filter is added.
    :param request_model: Query parameters; `llms`, `user_id` and
        `only_enabled` are read here.
    :returns: JSONResponse whose body is `{"items": [...]}`; every item has its
        Mongo `_id` exposed as `id` and always carries an `enabled` flag.
    """
    filters = []
    if request_model.llms:
        filters.append(
            MongoFilter(
                key="supported_llms",
                value=request_model.llms,
                logical_operator=MongoLogicalOperators.ALL,
            )
        )
    if request_model.user_id and request_model.user_id != "*":
        filters.append(MongoFilter(key="user_id", value=request_model.user_id))
    elif current_user:
        # No explicit owner requested: show public personas plus the caller's own
        user_filter = [{"user_id": None}, {"user_id": current_user.user_id}]
        filters.append(
            MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR)
        )
    if request_model.only_enabled:
        filters.append(MongoFilter(key="enabled", value=True))
    items = MongoDocumentsAPI.PERSONAS.list_items(
        filters=filters, result_as_cursor=False
    )
    for item in items:
        item["id"] = item.pop("_id")
        item["enabled"] = item.get("enabled", False)
    return JSONResponse(content={"items": items})
newline at end of file diff --git a/services/klatchat_observer/controller.py b/services/klatchat_observer/controller.py index 3c38bb37..1d1c8a60 100644 --- a/services/klatchat_observer/controller.py +++ b/services/klatchat_observer/controller.py @@ -29,6 +29,7 @@ import json import re import time +from typing import Optional from queue import Queue import cachetools.func @@ -338,6 +339,8 @@ def register_sio_handlers(self): handler=self.request_revoke_submind_ban_from_conversation, ) self._sio.on("auth_expired", handler=self._handle_auth_expired) + self._sio.on("configured_personas_changed", + handler=self._handle_personas_changed) def connect_sio(self): """ @@ -445,6 +448,9 @@ def get_neon_request_structure(msg_data: dict): return request_dict def _handle_neon_recipient(self, recipient_data: dict, msg_data: dict): + """ + Handle a chat message intended for Neon. + """ msg_data.setdefault("message_body", msg_data.pop("messageText", "")) msg_data.setdefault("message_id", msg_data.pop("messageID", "")) recipient_data.setdefault("context", {}) @@ -825,6 +831,9 @@ def on_subminds_state(self, body: dict): @create_mq_callback() def on_get_configured_personas(self, body: dict): + """ + Handles requests to get all defined personas for a specific LLM service + """ response_data = self._fetch_persona_api(user_id=body.get("user_id")) response_data["items"] = [ item @@ -842,24 +851,41 @@ def on_get_configured_personas(self, body: dict): ) @cachetools.func.ttl_cache(ttl=15) - def _fetch_persona_api(self, user_id: str) -> dict: + def _fetch_persona_api(self, user_id: Optional[str]) -> dict: query_string = self._build_persona_api_query(user_id=user_id) url = f"{self.server_url}/personas/list?{query_string}" try: response = self._fetch_klat_server(url=url) data = response.json() + data['update_time'] = time.time() self._refresh_default_persona_llms(data=data) except KlatAPIAuthorizationError: LOG.error(f"Failed to fetch personas from {url = }") data = {"items": []} return data + 
def _handle_personas_changed(self, data: dict): + """ + SIO handler called when configured personas are modified. This emits an + MQ message to allow any connected listeners to maintain a set of known + personas. + """ + for llm, personas in data["personas"].items(): + self.send_message( + request_data={ + "items": personas, + "update_time": data.get("update_time") or time.time()}, + vhost=self.get_vhost("llm"), + queue=f"{llm}_personas_input", + expiration=5000, + ) + def _refresh_default_persona_llms(self, data): for item in data["items"]: if default_llm := item.get("default_llm"): self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm - def _build_persona_api_query(self, user_id: str) -> str: + def _build_persona_api_query(self, user_id: Optional[str]) -> str: url_query_params = f"only_enabled=true" if user_id: url_query_params += f"&user_id={user_id}"