Skip to content

Commit

Permalink
feat: we now serve SSE via /stream
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaisberg authored and Gaisberg committed Oct 21, 2024
1 parent 28305ce commit efbc471
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 136 deletions.
19 changes: 8 additions & 11 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import sqlalchemy
from RTN import parse
from sqlalchemy import Index, UniqueConstraint
from sqlalchemy import Index
from sqlalchemy.orm import Mapped, mapped_column, object_session, relationship

import utils.websockets.manager as ws_manager
from utils.sse_manager import sse_manager
from program.db.db import db
from program.media.state import States
from program.media.subtitle import Subtitle
Expand Down Expand Up @@ -133,9 +133,10 @@ def __init__(self, item: dict | None) -> None:
self.subtitles = item.get("subtitles", [])

def store_state(self) -> None:
    """Recompute this item's state and publish an SSE update on change.

    Publishes an "item_update" event (old state, new state, item id) only
    when a previous state exists and differs from the newly determined one.
    """
    current = self._determine_state()
    state_changed = bool(self.last_state) and self.last_state != current
    if state_changed:
        payload = {
            "last_state": self.last_state,
            "new_state": current,
            "item_id": self._id,
        }
        sse_manager.publish_event("item_update", payload)
    self.last_state = current

def is_stream_blacklisted(self, stream: Stream):
"""Check if a stream is blacklisted for this item."""
Expand Down Expand Up @@ -458,9 +459,7 @@ def _determine_state(self):
def store_state(self) -> None:
    """Store state for every season first, then for the show itself."""
    for child in self.seasons:
        child.store_state()
    # Delegate the show's own state transition (and SSE publish) to MediaItem.
    super().store_state()

def __repr__(self):
return f"Show:{self.log_string}:{self.state.name}"
Expand Down Expand Up @@ -531,9 +530,7 @@ class Season(MediaItem):
def store_state(self) -> None:
    """Store state for every episode first, then for the season itself."""
    for child in self.episodes:
        child.store_state()
    # Delegate the season's own state transition (and SSE publish) to MediaItem.
    super().store_state()

def __init__(self, item):
self.type = "season"
Expand Down
4 changes: 2 additions & 2 deletions src/program/media/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StreamRelation(db.Model):
Index('ix_streamrelation_parent_id', 'parent_id'),
Index('ix_streamrelation_child_id', 'child_id'),
)

class StreamBlacklistRelation(db.Model):
__tablename__ = "StreamBlacklistRelation"

Expand Down Expand Up @@ -66,6 +66,6 @@ def __init__(self, torrent: Torrent):

def __hash__(self):
return self.infohash

def __eq__(self, other):
return isinstance(other, Stream) and self.infohash == other.infohash
3 changes: 0 additions & 3 deletions src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from queue import Empty
from typing import Iterator, List

from apscheduler.schedulers.background import BackgroundScheduler
from rich.live import Live

import utils.websockets.manager as ws_manager
from program.content import Listrr, Mdblist, Overseerr, PlexWatchlist, TraktContent
from program.downloaders import Downloader
from program.indexers.trakt import TraktIndexer
Expand Down Expand Up @@ -170,7 +168,6 @@ def start(self):
super().start()
self.scheduler.start()
logger.success("Riven is running!")
ws_manager.send_health_update("running")
self.initialized = True

def _retry_library(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from routers.secure.settings import router as settings_router
# from routers.secure.tmdb import router as tmdb_router
from routers.secure.webhooks import router as webooks_router
from routers.secure.ws import router as ws_router
from routers.secure.stream import router as stream_router

API_VERSION = "v1"

Expand All @@ -27,4 +27,4 @@ async def root(_: Request) -> RootResponse:
app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)])
# app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(stream_router, dependencies=[Depends(resolve_api_key)])
39 changes: 39 additions & 0 deletions src/routers/secure/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from datetime import datetime
import json
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from loguru import logger

import logging

from pydantic import BaseModel
from utils.sse_manager import sse_manager


router = APIRouter(
responses={404: {"description": "Not found"}},
prefix="/stream",
tags=["stream"],
)

# Generic SSE payload wrapper; `data` carries the event body.
# NOTE(review): the /{event_type} handler actually returns a
# StreamingResponse, not this model — confirm it is still needed.
class EventResponse(BaseModel):
    data: dict

class SSELogHandler(logging.Handler):
    """Forward stdlib logging records to SSE subscribers.

    Each record is serialized as a JSON object (time/level/message) and
    published on the "logging" event type of the shared SSE manager.
    """

    def emit(self, record: logging.LogRecord) -> None:
        log_entry = {
            "time": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            # getMessage() applies %-style args; record.msg is only the
            # raw, unformatted template.
            "message": record.getMessage(),
        }
        sse_manager.publish_event("logging", json.dumps(log_entry))

logger.add(SSELogHandler())

@router.get("/event_types")
async def get_event_types():
    """List the event types that currently have an SSE queue."""
    known_types = sse_manager.event_queues.keys()
    return {"message": list(known_types)}

@router.get("/{event_type}")
async def stream_events(_: Request, event_type: str) -> StreamingResponse:
    """Open a server-sent-events stream for *event_type*.

    Subscribes to the shared SSE manager and relays each published event
    as a `text/event-stream` frame until the client disconnects.
    The return annotation is StreamingResponse — the previous
    EventResponse annotation did not match what the handler returns.
    """
    return StreamingResponse(
        sse_manager.subscribe(event_type),
        media_type="text/event-stream",
    )
37 changes: 14 additions & 23 deletions src/utils/event_manager.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
import traceback

from datetime import datetime
from queue import Empty
from threading import Lock
from typing import Dict, List

from loguru import logger
from pydantic import BaseModel
from sqlalchemy.orm.exc import StaleDataError
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor

import utils.websockets.manager as ws_manager
from utils.sse_manager import sse_manager
from program.db.db import db
from program.db.db_functions import (
ensure_item_exists_in_db,
Expand Down Expand Up @@ -73,7 +74,7 @@ def _process_future(self, future, service):
result = next(future.result(), None)
if future in self._futures:
self._futures.remove(future)
ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")])
sse_manager.publish_event("event_update", self.get_event_updates())
if isinstance(result, tuple):
item_id, timestamp = result
else:
Expand Down Expand Up @@ -170,7 +171,7 @@ def submit_job(self, service, program, event=None):
if event:
future.event = event
self._futures.append(future)
ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")])
sse_manager.publish_event("event_update", self.get_event_updates())
future.add_done_callback(lambda f:self._process_future(f, service))

def cancel_job(self, item_id: int, suppress_logs=False):
Expand Down Expand Up @@ -310,24 +311,14 @@ def add_item(self, item, service="Manual"):
logger.debug(f"Added item with ID {item_id} to the queue.")


def get_event_updates(self) -> dict[str, list[EventUpdate]]:
"""
Returns a formatted list of event updates.
Returns:
list: The list of formatted event updates.
"""
def get_event_updates(self) -> Dict[str, List[int]]:
    """Group in-flight event item IDs by the service that emitted them.

    Returns:
        Mapping of service name ("Scraping", "Downloader", "Symlinker",
        "Updater", "PostProcessing") to the list of item IDs currently
        queued/running for that service; unknown emitters are skipped.
    """
    events = [future.event for future in self._futures if hasattr(future, "event")]
    event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"]

    updates: Dict[str, List[int]] = {event_type: [] for event_type in event_types}
    for event in events:
        # emitted_by may be either a service class or already a plain str
        # (both forms are produced elsewhere in the codebase); reading
        # __name__ unconditionally raises AttributeError on str emitters.
        emitter = (
            event.emitted_by
            if isinstance(event.emitted_by, str)
            else event.emitted_by.__name__
        )
        bucket = updates.get(emitter)
        if bucket is not None:
            bucket.append(event.item_id)

    return updates
33 changes: 9 additions & 24 deletions src/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Logging utils"""

import asyncio
import os
import sys
from datetime import datetime
Expand All @@ -17,7 +16,6 @@

from program.settings.manager import settings_manager
from utils import data_dir_path
from utils.websockets.logging_handler import Handler as WebSocketHandler

LOG_ENABLED: bool = settings_manager.settings.log

Expand Down Expand Up @@ -67,7 +65,7 @@ def get_log_settings(name, default_color, default_icon):
warning_color, warning_icon = get_log_settings("WARNING", "ffcc00", "⚠️ ")
critical_color, critical_icon = get_log_settings("CRITICAL", "ff0000", "")
success_color, success_icon = get_log_settings("SUCCESS", "00ff00", "✔️ ")

logger.level("DEBUG", color=debug_color, icon=debug_icon)
logger.level("INFO", color=info_color, icon=info_icon)
logger.level("WARNING", color=warning_color, icon=warning_icon)
Expand All @@ -91,30 +89,18 @@ def get_log_settings(name, default_color, default_icon):
"enqueue": True,
},
{
"sink": log_filename,
"level": level.upper(),
"format": log_format,
"rotation": "25 MB",
"retention": "24 hours",
"compression": None,
"backtrace": False,
"sink": log_filename,
"level": level.upper(),
"format": log_format,
"rotation": "25 MB",
"retention": "24 hours",
"compression": None,
"backtrace": False,
"diagnose": True,
"enqueue": True,
},
# maybe later
# {
# "sink": manager.send_log_message,
# "level": level.upper() or "INFO",
# "format": log_format,
# "backtrace": False,
# "diagnose": False,
# "enqueue": True,
# }
}
])

logger.add(WebSocketHandler(), format=log_format)


def log_cleaner():
"""Remove old log files based on retention settings."""
cleaned = False
Expand All @@ -130,7 +116,6 @@ def log_cleaner():
except Exception as e:
logger.error(f"Failed to clean old logs: {e}")


def create_progress_bar(total_items: int) -> tuple[Progress, Console]:
console = Console()
progress = Progress(
Expand Down
26 changes: 26 additions & 0 deletions src/utils/sse_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio
from typing import Dict, Any

class ServerSentEventManager:
def __init__(self):
self.event_queues: Dict[str, asyncio.Queue] = {}

def publish_event(self, event_type: str, data: Any):
if not data:
return
if event_type not in self.event_queues:
self.event_queues[event_type] = asyncio.Queue()
self.event_queues[event_type].put_nowait(data)

async def subscribe(self, event_type: str):
if event_type not in self.event_queues:
self.event_queues[event_type] = asyncio.Queue()

while True:
try:
data = await asyncio.wait_for(self.event_queues[event_type].get(), timeout=1.0)
yield f"data: {data}\n\n"
except asyncio.TimeoutError:
pass

sse_manager = ServerSentEventManager()
12 changes: 0 additions & 12 deletions src/utils/websockets/logging_handler.py

This file was deleted.

59 changes: 0 additions & 59 deletions src/utils/websockets/manager.py

This file was deleted.

0 comments on commit efbc471

Please sign in to comment.