Skip to content

Commit

Permalink
sio improvement for klat observer (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonKirill authored Dec 13, 2024
1 parent 6b09fba commit a304386
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions services/klatchat_observer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
import json
import re
import time
from queue import Queue

import cachetools.func

from threading import Event, Timer

import requests
import socketio

from socketio.exceptions import SocketIOError
from enum import Enum

from neon_mq_connector.utils.rabbit_utils import create_mq_callback
Expand Down Expand Up @@ -98,6 +101,8 @@ def __init__(

self._sio: socketio.Client = socketio.Client()
self.sio_connected = False
self.sio_connecting = False
self.sio_queued_messages = Queue(maxsize=256)
self.register_sio_handlers()

self.server_url = self.sio_url
Expand Down Expand Up @@ -339,6 +344,8 @@ def connect_sio(self):
"""
Method for establishing connection with Socket IO server
"""
# This flag is needed to lock parallel attempts
self.sio_connecting = True
self._sio.connect(
url=self.sio_url,
namespaces=["/"],
Expand All @@ -348,6 +355,13 @@ def connect_sio(self):
"nano_session": self._klat_nano_token,
},
)
time.sleep(1)
LOG.info("Socket IO reconnected")
self.sio_connecting = False
self.sio_connected = True
while not self.sio_queued_messages.empty():
self._sio_emit(**self.sio_queued_messages.get())
time.sleep(0.1)

@property
def sio(self):
Expand All @@ -357,14 +371,14 @@ def sio(self):
:return: connected async socket io instance
"""
if not self.sio_connected:
if not self.sio_connecting:
try:
# Assuming that Socket IO is connected unless Exception is raised during the connection
# This is done to prevent parallel invocation of this method from consumers
self.sio_connected = True
self.connect_sio()
except Exception as ex:
LOG.error(f"Failed to connect to sio: {ex}")
self.sio_connecting = False
self.sio_connected = False
return self._sio

Expand Down Expand Up @@ -573,7 +587,7 @@ def send_translation_response(self, data: dict):
"""
request_id = data.get("request_id", None)
if request_id and self.__translation_requests.pop(request_id, None):
self.sio.emit("get_neon_translations", data=data)
self._sio_emit("get_neon_translations", data=data)
else:
LOG.debug(
f"Neon translation response was not sent, "
Expand All @@ -591,7 +605,7 @@ def handle_get_prompt(self, body: dict):
if not requested_nick:
LOG.warning("Request from unknown sender, skipping")
return -1
self.sio.emit("get_prompt_data", data=body)
self._sio_emit("get_prompt_data", data=body)

@create_mq_callback()
def handle_neon_sync(self, body: dict):
Expand Down Expand Up @@ -653,7 +667,7 @@ def handle_neon_response(self, body: dict):
LOG.error(
f"Failed to set messageTTS with language={language}, gender={gender} - {ex}"
)
self.sio.emit("user_message", data=send_data)
self._sio_emit("user_message", data=send_data)
except Exception as ex:
LOG.error(f"Failed to emit Neon Chat API response: {ex}")

Expand Down Expand Up @@ -685,7 +699,7 @@ def handle_submind_shout(self, body: dict):
if all(required_key in list(body) for required_key in response_required_keys):
body.setdefault("timeCreated", int(time.time()))
body.setdefault("source", "klat_observer")
self.sio.emit("user_message", data=body)
self._sio_emit("user_message", data=body)
self.handle_message(data=body)
else:
error_msg = (
Expand Down Expand Up @@ -713,7 +727,7 @@ def handle_new_prompt(self, body: dict):
"prompt_text",
)
if all(required_key in list(body) for required_key in response_required_keys):
self.sio.emit("new_prompt", data=body)
self._sio_emit("new_prompt", data=body)
else:
error_msg = (
f"Skipping received data {body} as it lacks one of the required keys: "
Expand Down Expand Up @@ -745,7 +759,7 @@ def handle_saving_prompt_data(self, body: dict):
)

if all(required_key in list(body) for required_key in response_required_keys):
self.sio.emit("prompt_completed", data=body)
self._sio_emit("prompt_completed", data=body)
else:
error_msg = (
f"Skipping received data {body} as it lacks one of the required keys: "
Expand All @@ -763,20 +777,20 @@ def handle_saving_prompt_data(self, body: dict):
def on_stt_response(self, body: dict):
"""Handles receiving STT response"""
LOG.debug(f"Received STT Response: {body}")
self.sio.emit("stt_response", data=body)
self._sio_emit("stt_response", data=body)

@create_mq_callback()
def on_tts_response(self, body: dict):
"""Handles receiving TTS response"""
LOG.debug(f"Received TTS Response: {body}")
self.sio.emit("tts_response", data=body)
self._sio_emit("tts_response", data=body)

@create_mq_callback()
def on_subminds_state(self, body: dict):
"""Handles receiving subminds state message"""
LOG.debug(f"Received submind state: {body}")
body["msg_type"] = "subminds_state"
self.sio.emit("broadcast", data=body)
self._sio_emit("broadcast", data=body)

@create_mq_callback()
def on_get_configured_personas(self, body: dict):
Expand Down Expand Up @@ -885,3 +899,15 @@ def request_revoke_submind_ban_from_conversation(self, data: dict):
queue="revoke_submind_ban_from_conversation",
expiration=3000,
)

def _sio_emit(self, event: str, data: dict):
if self.sio_connected:
try:
self.sio.emit(event=event, data=data)
except SocketIOError as err:
LOG.error(
f"Socket IO event={event} emit failed due to error: {err}, reconnecting"
)
self.sio_connected = False
else:
self.sio_queued_messages.put(item={"event": event, "data": data})

0 comments on commit a304386

Please sign in to comment.