diff --git a/src/program/media/item.py b/src/program/media/item.py
index 4cec0767..e00ffd3a 100644
--- a/src/program/media/item.py
+++ b/src/program/media/item.py
@@ -6,10 +6,10 @@
 
 import sqlalchemy
 from RTN import parse
-from sqlalchemy import Index, UniqueConstraint
+from sqlalchemy import Index
 from sqlalchemy.orm import Mapped, mapped_column, object_session, relationship
 
-import utils.websockets.manager as ws_manager
+from utils.sse_manager import sse_manager
 from program.db.db import db
 from program.media.state import States
 from program.media.subtitle import Subtitle
@@ -133,9 +133,10 @@ def __init__(self, item: dict | None) -> None:
         self.subtitles = item.get("subtitles", [])
 
     def store_state(self) -> None:
-        if self.last_state and self.last_state != self._determine_state():
-            ws_manager.send_item_update(json.dumps(self.to_dict()))
-            self.last_state = self._determine_state()
+        new_state = self._determine_state()
+        if self.last_state and self.last_state != new_state:
+            sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id})
+        self.last_state = new_state
 
     def is_stream_blacklisted(self, stream: Stream):
         """Check if a stream is blacklisted for this item."""
@@ -458,9 +459,7 @@ def _determine_state(self):
     def store_state(self) -> None:
         for season in self.seasons:
             season.store_state()
-        if self.last_state and self.last_state != self._determine_state():
-            ws_manager.send_item_update(json.dumps(self.to_dict()))
-            self.last_state = self._determine_state()
+        super().store_state()
 
     def __repr__(self):
         return f"Show:{self.log_string}:{self.state.name}"
@@ -531,9 +530,7 @@ class Season(MediaItem):
     def store_state(self) -> None:
         for episode in self.episodes:
             episode.store_state()
-        if self.last_state and self.last_state != self._determine_state():
-            ws_manager.send_item_update(json.dumps(self.to_dict()))
-            self.last_state = self._determine_state()
+        super().store_state()
 
     def __init__(self, item):
         self.type = "season"
diff --git a/src/program/media/stream.py b/src/program/media/stream.py
index 2e803f13..4a8fcc81 100644
--- a/src/program/media/stream.py
+++ b/src/program/media/stream.py
@@ -24,7 +24,7 @@ class StreamRelation(db.Model):
         Index('ix_streamrelation_parent_id', 'parent_id'),
         Index('ix_streamrelation_child_id', 'child_id'),
     )
-    
+
 
 class StreamBlacklistRelation(db.Model):
     __tablename__ = "StreamBlacklistRelation"
@@ -66,6 +66,6 @@ def __init__(self, torrent: Torrent):
 
     def __hash__(self):
         return self.infohash
-    
+
     def __eq__(self, other):
         return isinstance(other, Stream) and self.infohash == other.infohash
\ No newline at end of file
diff --git a/src/program/program.py b/src/program/program.py
index d03394d1..153d88cb 100644
--- a/src/program/program.py
+++ b/src/program/program.py
@@ -5,12 +5,10 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime
 from queue import Empty
-from typing import Iterator, List
 
 from apscheduler.schedulers.background import BackgroundScheduler
 from rich.live import Live
 
-import utils.websockets.manager as ws_manager
 from program.content import Listrr, Mdblist, Overseerr, PlexWatchlist, TraktContent
 from program.downloaders import Downloader
 from program.indexers.trakt import TraktIndexer
@@ -170,7 +168,6 @@ def start(self):
         super().start()
         self.scheduler.start()
         logger.success("Riven is running!")
-        ws_manager.send_health_update("running")
         self.initialized = True
 
     def _retry_library(self) -> None:
diff --git a/src/routers/__init__.py b/src/routers/__init__.py
index 1f370d05..34e0816b 100644
--- a/src/routers/__init__.py
+++ b/src/routers/__init__.py
@@ -9,7 +9,7 @@
 from routers.secure.settings import router as settings_router
 # from routers.secure.tmdb import router as tmdb_router
 from routers.secure.webhooks import router as webooks_router
-from routers.secure.ws import router as ws_router
+from routers.secure.stream import router as stream_router
 
 
 API_VERSION = "v1"
@@ -27,4 +27,4 @@ async def root(_: Request) -> RootResponse:
 app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)])
 # app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)])
 app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)])
-app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)])
\ No newline at end of file
+app_router.include_router(stream_router, dependencies=[Depends(resolve_api_key)])
\ No newline at end of file
diff --git a/src/routers/secure/stream.py b/src/routers/secure/stream.py
new file mode 100644
index 00000000..65452a83
--- /dev/null
+++ b/src/routers/secure/stream.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+import json
+from fastapi import APIRouter, Request
+from fastapi.responses import StreamingResponse
+from loguru import logger
+
+import logging
+
+from pydantic import BaseModel
+from utils.sse_manager import sse_manager
+
+
+router = APIRouter(
+    responses={404: {"description": "Not found"}},
+    prefix="/stream",
+    tags=["stream"],
+)
+
+class EventResponse(BaseModel):
+    data: dict
+
+class SSELogHandler(logging.Handler):
+    def emit(self, record: logging.LogRecord):
+        log_entry = {
+            "time": datetime.fromtimestamp(record.created).isoformat(),
+            "level": record.levelname,
+            "message": record.msg
+        }
+        sse_manager.publish_event("logging", json.dumps(log_entry))
+
+logger.add(SSELogHandler())
+
+@router.get("/event_types")
+async def get_event_types():
+    return {"message": list(sse_manager.event_queues.keys())}
+
+@router.get("/{event_type}")
+async def stream_events(_: Request, event_type: str) -> EventResponse:
+    return StreamingResponse(sse_manager.subscribe(event_type), media_type="text/event-stream")
\ No newline at end of file
diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py
index 97290a64..32c1c442 100644
--- a/src/utils/event_manager.py
+++ b/src/utils/event_manager.py
@@ -1,16 +1,17 @@
 import os
 import traceback
-    
+
 from datetime import datetime
 from queue import Empty
 from threading import Lock
+from typing import Dict, List
 
 from loguru import logger
 from pydantic import BaseModel
 from sqlalchemy.orm.exc import StaleDataError
 from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
 
-import utils.websockets.manager as ws_manager
+from utils.sse_manager import sse_manager
 from program.db.db import db
 from program.db.db_functions import (
     ensure_item_exists_in_db,
@@ -73,7 +74,7 @@ def _process_future(self, future, service):
             result = next(future.result(), None)
             if future in self._futures:
                 self._futures.remove(future)
-            ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")])
+            sse_manager.publish_event("event_update", self.get_event_updates())
             if isinstance(result, tuple):
                 item_id, timestamp = result
             else:
@@ -170,7 +171,7 @@ def submit_job(self, service, program, event=None):
         if event:
            future.event = event
         self._futures.append(future)
-        ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")])
+        sse_manager.publish_event("event_update", self.get_event_updates())
         future.add_done_callback(lambda f:self._process_future(f, service))
 
     def cancel_job(self, item_id: int, suppress_logs=False):
@@ -310,24 +311,14 @@ def add_item(self, item, service="Manual"):
 
             logger.debug(f"Added item with ID {item_id} to the queue.")
 
-    def get_event_updates(self) -> dict[str, list[EventUpdate]]:
-        """
-        Returns a formatted list of event updates.
-
-        Returns:
-            list: The list of formatted event updates.
-        """
+    def get_event_updates(self) -> Dict[str, List[int]]:
         events = [future.event for future in self._futures if hasattr(future, "event")]
         event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"]
-        return {
-            event_type.lower(): [
-                EventUpdate.model_validate(
-                    {
-                        "item_id": event.item_id,
-                        "emitted_by": event.emitted_by if isinstance(event.emitted_by, str) else event.emitted_by.__name__,
-                        "run_at": event.run_at.isoformat()
-                    })
-                for event in events if event.emitted_by == event_type
-            ]
-            for event_type in event_types
-        }
\ No newline at end of file
+
+        updates = {event_type: [] for event_type in event_types}
+        for event in events:
+            table = updates.get(event.emitted_by.__name__, None)
+            if table is not None:
+                table.append(event.item_id)
+
+        return updates
\ No newline at end of file
diff --git a/src/utils/logging.py b/src/utils/logging.py
index 5306f26b..5c355c53 100644
--- a/src/utils/logging.py
+++ b/src/utils/logging.py
@@ -1,6 +1,5 @@
 """Logging utils"""
-import asyncio
 import os
 import sys
 from datetime import datetime
 
@@ -17,7 +16,6 @@
 from program.settings.manager import settings_manager
 from utils import data_dir_path
-from utils.websockets.logging_handler import Handler as WebSocketHandler
 
 
 LOG_ENABLED: bool = settings_manager.settings.log
 
@@ -67,7 +65,7 @@ def get_log_settings(name, default_color, default_icon):
     warning_color, warning_icon = get_log_settings("WARNING", "ffcc00", "⚠️ ")
     critical_color, critical_icon = get_log_settings("CRITICAL", "ff0000", "")
     success_color, success_icon = get_log_settings("SUCCESS", "00ff00", "✔️ ")
-    
+
     logger.level("DEBUG", color=debug_color, icon=debug_icon)
     logger.level("INFO", color=info_color, icon=info_icon)
     logger.level("WARNING", color=warning_color, icon=warning_icon)
@@ -91,30 +89,18 @@
             "enqueue": True,
         },
         {
-            "sink": log_filename, 
-            "level": level.upper(), 
-            "format": log_format, 
-            "rotation": "25 MB", 
-            "retention": "24 hours", 
-            "compression": None, 
-            "backtrace": False, 
+            "sink": log_filename,
+            "level": level.upper(),
+            "format": log_format,
+            "rotation": "25 MB",
+            "retention": "24 hours",
+            "compression": None,
+            "backtrace": False,
             "diagnose": True,
             "enqueue": True,
-        },
-        # maybe later
-        # {
-        #     "sink": manager.send_log_message,
-        #     "level": level.upper() or "INFO",
-        #     "format": log_format,
-        #     "backtrace": False,
-        #     "diagnose": False,
-        #     "enqueue": True,
-        # }
+        }
     ])
 
-    logger.add(WebSocketHandler(), format=log_format)
-
-
 def log_cleaner():
     """Remove old log files based on retention settings."""
     cleaned = False
@@ -130,7 +116,6 @@ def log_cleaner():
         except Exception as e:
             logger.error(f"Failed to clean old logs: {e}")
 
-
 def create_progress_bar(total_items: int) -> tuple[Progress, Console]:
     console = Console()
     progress = Progress(
diff --git a/src/utils/sse_manager.py b/src/utils/sse_manager.py
new file mode 100644
index 00000000..1b284d06
--- /dev/null
+++ b/src/utils/sse_manager.py
@@ -0,0 +1,26 @@
+import asyncio
+from typing import Dict, Any
+
+class ServerSentEventManager:
+    def __init__(self):
+        self.event_queues: Dict[str, asyncio.Queue] = {}
+
+    def publish_event(self, event_type: str, data: Any):
+        if not data:
+            return
+        if event_type not in self.event_queues:
+            self.event_queues[event_type] = asyncio.Queue()
+        self.event_queues[event_type].put_nowait(data)
+
+    async def subscribe(self, event_type: str):
+        if event_type not in self.event_queues:
+            self.event_queues[event_type] = asyncio.Queue()
+
+        while True:
+            try:
+                data = await asyncio.wait_for(self.event_queues[event_type].get(), timeout=1.0)
+                yield f"data: {data}\n\n"
+            except asyncio.TimeoutError:
+                pass
+
+sse_manager = ServerSentEventManager()
\ No newline at end of file
diff --git a/src/utils/websockets/logging_handler.py b/src/utils/websockets/logging_handler.py
deleted file mode 100644
index 163c624e..00000000
--- a/src/utils/websockets/logging_handler.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import logging
-
-from utils.websockets import manager
-
-
-class Handler(logging.Handler):
-    def emit(self, record: logging.LogRecord):
-        try:
-            message = self.format(record)
-            manager.send_log_message(message)
-        except Exception:
-            self.handleError(record)
\ No newline at end of file
diff --git a/src/utils/websockets/manager.py b/src/utils/websockets/manager.py
deleted file mode 100644
index fe20e14d..00000000
--- a/src/utils/websockets/manager.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import asyncio
-import json
-
-from fastapi import WebSocket
-from loguru import logger
-
-active_connections = []
-
-async def connect(websocket: WebSocket):
-    await websocket.accept()
-    existing_connection = next((connection for connection in active_connections if connection.app == websocket.app), None)
-    if not existing_connection:
-        logger.debug("Frontend connected!")
-        active_connections.append(websocket)
-    if websocket.app.program.initialized:
-        status = "running"
-    else:
-        status = "paused"
-    await websocket.send_json({"type": "health", "status": status})
-
-def disconnect(websocket: WebSocket):
-    logger.debug("Frontend disconnected!")
-    existing_connection = next((connection for connection in active_connections if connection.app == websocket.app), None)
-    active_connections.remove(existing_connection)
-
-async def _send_json(message: json, websocket: WebSocket):
-    try:
-        await websocket.send_json(message)
-    except Exception:
-        pass
-
-def send_event_update(events: list):
-    event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"]
-    message = {event_type.lower(): [event.item_id for event in events if event.emitted_by == event_type] for event_type in event_types}
-    broadcast({"type": "event_update", "message": message})
-
-def send_health_update(status: str):
-    broadcast({"type": "health", "status": status})
-
-def send_log_message(message: str):
-    broadcast({"type": "log", "message": message})
-
-def send_item_update(item: json):
-    broadcast({"type": "item_update", "item": item})
-
-def broadcast(message: json):
-    for connection in active_connections:
-        event_loop = None
-        try:
-            event_loop = asyncio.get_event_loop()
-        except RuntimeError:
-            pass
-        try:
-            if event_loop and event_loop.is_running():
-                asyncio.create_task(_send_json(message, connection))
-            else:
-                asyncio.run(_send_json(message, connection))
-        except Exception:
-            pass
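
A minimal client sketch for the new /stream endpoints (illustrative only, not part of this diff; it assumes the API is served at http://localhost:8080/api/v1 and that resolve_api_key reads an "x-api-key" header — both are assumptions to adjust for your deployment):

import requests

API = "http://localhost:8080/api/v1"          # assumed mount point
HEADERS = {"x-api-key": "your-api-key"}       # hypothetical header name

# Discover which event types currently have queues.
print(requests.get(f"{API}/stream/event_types", headers=HEADERS).json())

# Subscribe to item updates; each SSE frame arrives as a "data: ..." line.
with requests.get(f"{API}/stream/item_update", headers=HEADERS, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(line[len("data: "):])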
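
One encoding detail in the new SSE path: publish_event receives plain dicts from store_state() and get_event_updates(), while SSELogHandler publishes a json.dumps string, and subscribe() interpolates whatever it gets straight into f"data: {data}\n\n" — so dict payloads stream as Python reprs (single quotes) rather than JSON. A hedged sketch of one way to normalize this at publish time, keeping the same class shape as sse_manager.py (an alternative sketch, not the PR's code):

import asyncio
import json
from typing import Any, Dict

class ServerSentEventManager:
    def __init__(self):
        self.event_queues: Dict[str, asyncio.Queue] = {}

    def publish_event(self, event_type: str, data: Any):
        if not data:
            return
        # JSON-encode non-string payloads once here so every "data:" frame
        # is machine-parseable regardless of what callers pass in.
        if not isinstance(data, str):
            data = json.dumps(data, default=str)
        if event_type not in self.event_queues:
            self.event_queues[event_type] = asyncio.Queue()
        self.event_queues[event_type].put_nowait(data)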