Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Klat persona update events #10

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion neon_llm_core/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from abc import abstractmethod, ABC
from threading import Thread
from threading import Thread, Lock
from time import time

from neon_mq_connector.connector import MQConnector
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
Expand All @@ -52,6 +53,8 @@ def __init__(self):
self.register_consumers()
self._model = None
self._bots = list()
self._persona_update_lock = Lock()
self._last_persona_update = time()
self._personas_provider = PersonasProvider(service_name=self.name,
ovos_config=self.ovos_config)

Expand All @@ -72,6 +75,11 @@ def register_consumers(self):
queue=self.queue_opinion,
callback=self.handle_opinion_request,
on_error=self.default_error_handler,)
self.register_consumer(name=f'neon_llm_{self.name}_personas',
vhost=self.vhost,
queue=self.queue_personas,
callback=self.handle_new_personas,
on_error=self.default_error_handler)

@property
@abstractmethod
Expand All @@ -97,6 +105,10 @@ def queue_score(self):
def queue_opinion(self):
return f"{self.name}_discussion_input"

@property
def queue_personas(self):
    """
    Name of the MQ queue this service consumes persona definitions from.
    :return: queue name built from this service's `name`
    """
    queue_name = f"{self.name}_personas_input"
    return queue_name

@property
@abstractmethod
def model(self) -> NeonLLM:
Expand All @@ -113,6 +125,20 @@ def handle_request(self, body: dict):
Thread(target=self._handle_request_async, args=(body,),
daemon=True).start()

@create_mq_callback()
def handle_new_personas(self, body: dict):
    """
    Handles an emitted message from the server containing personas defined
    for this LLM. Updates that are not newer than the last applied update
    are ignored.
    :param body: MQ message body containing persona definitions; an optional
        `update_time` value is compared with the time of the last applied
        update. If it is missing or None, the current time is used.
    """
    # Resolve the timestamp once so the staleness check and the stored
    # value always agree; an explicit `None` must not reach the comparison.
    update_time = body.get("update_time")
    if update_time is None:
        update_time = time()
    # Check and update under the same lock so that concurrent messages
    # cannot both pass the staleness check and be applied out of order.
    with self._persona_update_lock:
        if update_time <= self._last_persona_update:
            LOG.info("Skipping update that is older than last update")
            return
        self._personas_provider.parse_persona_response(body)
        self._last_persona_update = update_time

def _handle_request_async(self, request: dict):
message_id = request["message_id"]
routing_key = request["routing_key"]
Expand Down
53 changes: 41 additions & 12 deletions neon_llm_core/utils/personas/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
from time import time
from typing import List

from neon_mq_connector.utils import RepeatingTimer
from neon_mq_connector.utils.client_utils import send_mq_request
Expand All @@ -36,9 +37,13 @@


class PersonasProvider:
"""
Manages personas defined via Klat. Each LLM that connects to the MQ bus will
include an instance of this object to track changes to personas.
"""

PERSONA_STATE_TTL = int(os.getenv("PERSONA_STATE_TTL", 15 * 60))
PERSONA_SYNC_INTERVAL = int(os.getenv("PERSONA_SYNC_INTERVAL", 5 * 60))
PERSONA_SYNC_INTERVAL = int(os.getenv("PERSONA_SYNC_INTERVAL", 0))
GET_CONFIGURED_PERSONAS_QUEUE = "get_configured_personas"

def __init__(self, service_name: str, ovos_config: dict):
Expand All @@ -50,7 +55,7 @@ def __init__(self, service_name: str, ovos_config: dict):
self._persona_sync_thread = None

@property
def persona_sync_thread(self):
def persona_sync_thread(self) -> RepeatingTimer:
"""Creates new synchronization thread which fetches Klat personas"""
if not (isinstance(self._persona_sync_thread, RepeatingTimer) and
self._persona_sync_thread.is_alive()):
Expand All @@ -60,11 +65,11 @@ def persona_sync_thread(self):
return self._persona_sync_thread

@property
def personas(self):
def personas(self) -> List[PersonaModel]:
return self._personas

@personas.setter
def personas(self, data):
def personas(self, data: List[PersonaModel]):
LOG.debug(f'Setting personas={data}')
if self._should_reset_personas(data=data):
LOG.warning(f'Persona state TTL expired, resetting personas config')
Expand All @@ -74,18 +79,34 @@ def personas(self, data):
self._personas = data
self._persona_handlers_state.clean_up_personas(ignore_items=self._personas)

def _should_reset_personas(self, data) -> bool:
def _should_reset_personas(self, data: List[PersonaModel]) -> bool:
"""
Checks if personas should be re-initialized after setting a new value
for personas.
:param data: requested list of personas
:return: True if requested `data` should be ignored and personas
reloaded from config, False if requested `data` should be used
directly
"""
return (not (self._persona_last_sync == 0 and data)
and int(time()) - self._persona_last_sync > self.PERSONA_STATE_TTL)

def _fetch_persona_config(self):
response = send_mq_request(vhost=LLM_VHOST,
request_data={"service_name": self.service_name},
target_queue=PersonasProvider.GET_CONFIGURED_PERSONAS_QUEUE,
timeout=60)
if 'items' in response:
"""
Get personas from a provider on the MQ bus and update the internal
`personas` reference.
"""
response = send_mq_request(
vhost=LLM_VHOST,
request_data={"service_name": self.service_name},
target_queue=PersonasProvider.GET_CONFIGURED_PERSONAS_QUEUE,
timeout=60)
self.parse_persona_response(response)

def parse_persona_response(self, persona_response: dict):
if 'items' in persona_response:
self._persona_last_sync = int(time())
response_data = response.get('items', [])
response_data = persona_response.get('items', [])
personas = []
for item in response_data:
item.setdefault('name', item.pop('persona_name', None))
Expand All @@ -95,10 +116,18 @@ def _fetch_persona_config(self):
self.personas = personas

def start_sync(self):
"""
Update personas and start thread to periodically update from a service
on the MQ bus.
"""
self._fetch_persona_config()
self.persona_sync_thread.start()
if self.PERSONA_SYNC_INTERVAL > 0:
self.persona_sync_thread.start()

def stop_sync(self):
    """
    Cancel the persona synchronization thread, if one has been created, and
    drop the reference to it.
    """
    sync_thread = self._persona_sync_thread
    if sync_thread:
        sync_thread.cancel()
        self._persona_sync_thread = None
24 changes: 20 additions & 4 deletions neon_llm_core/utils/personas/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import time
from typing import Dict, Union, List
from typing import Dict, List, Optional

from neon_utils.logger import LOG

Expand All @@ -36,6 +36,10 @@


class PersonaHandlersState:
"""
This works with the PersonasProvider object to manage LLMBot instances for
all configured personas.
"""

def __init__(self, service_name: str, ovos_config: dict):
self._created_items: Dict[str, LLMBot] = {}
Expand All @@ -44,17 +48,29 @@ def __init__(self, service_name: str, ovos_config: dict):
self.mq_config = ovos_config.get('MQ', {})

def init_default_handlers(self):
"""
Initializes LLMBot instances for all personas defined in configuration.
"""
self._created_items = {}
if self.ovos_config.get("llm_bots", {}).get(self.service_name):
LOG.info(f"Chatbot(s) configured for: {self.service_name}")
for persona in self.ovos_config['llm_bots'][self.service_name]:
self.add_persona_handler(persona=PersonaModel.parse_obj(obj=persona))
self.add_persona_handler(
persona=PersonaModel.parse_obj(obj=persona))

def add_persona_handler(self, persona: PersonaModel) -> Union[LLMBot, None]:
def add_persona_handler(self, persona: PersonaModel) -> Optional[LLMBot]:
"""
Creates an `LLMBot` instance for the given persona if the persona does
not yet exist AND the persona is not disabled in configuration.
:param persona: Persona definition to generate an LLMBot instance of
:return: New or existing LLMBot instance or None if the persona is
disabled
"""
persona_dict = persona.model_dump()
if persona.id in list(self._created_items):
if self._created_items[persona.id].persona != persona_dict:
LOG.warning(f"Overriding already existing persona: '{persona.id}' with new data={persona}")
LOG.warning(f"Overriding already existing persona: "
f"'{persona.id}' with new data={persona}")
try:
self._created_items[persona.id].stop()
# time to gracefully stop the submind
Expand Down
Loading