diff --git a/.env.example b/.env.example index cba5b9b95..2c5bcdc6c 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,9 @@ HARD_RESET=false # This will attempt to fix broken symlinks in the library, and then exit after running! REPAIR_SYMLINKS=false +# Manual api key, must be 32 characters long +API_KEY=1234567890qwertyuiopas + # This is the number of workers to use for reindexing symlinks after a database reset. # More workers = faster symlinking but uses more memory. # For lower end machines, stick to around 1-3. diff --git a/src/auth.py b/src/auth.py new file mode 100644 index 000000000..db3ac9691 --- /dev/null +++ b/src/auth.py @@ -0,0 +1,14 @@ +from fastapi import HTTPException, Security, status +from fastapi.security import APIKeyHeader +from program.settings.manager import settings_manager + +api_key_header = APIKeyHeader(name="x-api-key") + +def resolve_api_key(api_key_header: str = Security(api_key_header)): + if api_key_header == settings_manager.settings.api_key: + return True + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing or invalid API key" + ) \ No newline at end of file diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py deleted file mode 100644 index 53b5fefc0..000000000 --- a/src/controllers/models/shared.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import BaseModel - - -class MessageResponse(BaseModel): - message: str \ No newline at end of file diff --git a/src/main.py b/src/main.py index ea321b96b..d266d9b3e 100644 --- a/src/main.py +++ b/src/main.py @@ -8,19 +8,12 @@ import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request - -from controllers.default import router as default_router -from controllers.items import router as items_router -from controllers.scrape import router as scrape_router -from controllers.settings import router as settings_router -from controllers.tmdb import router as tmdb_router -from controllers.webhooks import router as webhooks_router -from controllers.ws import router as ws_router -from scalar_fastapi import get_scalar_api_reference from program import Program from program.settings.models import get_version +from routers import app_router +from scalar_fastapi import get_scalar_api_reference +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request from utils.cli import handle_args from utils.logger import logger @@ -62,7 +55,6 @@ async def scalar_html(): ) app.program = Program() - app.add_middleware(LoguruMiddleware) app.add_middleware( CORSMiddleware, @@ -72,14 +64,7 @@ async def scalar_html(): allow_headers=["*"], ) -app.include_router(default_router) -app.include_router(settings_router) -app.include_router(items_router) -app.include_router(scrape_router) -app.include_router(webhooks_router) -app.include_router(tmdb_router) -app.include_router(ws_router) - +app.include_router(app_router) class Server(uvicorn.Server): def install_signal_handlers(self): diff --git a/src/program/settings/models.py b/src/program/settings/models.py index 2af11b9e3..de1a6864d 100644 --- a/src/program/settings/models.py +++ b/src/program/settings/models.py @@ -7,7 +7,7 @@ from RTN.models import SettingsModel from program.settings.migratable import MigratableBaseModel -from utils import root_dir +from utils import generate_api_key, get_version deprecation_warning = "This has been deprecated and will be removed in a future version." @@ -311,18 +311,6 @@ class RTNSettingsModel(SettingsModel, Observable): class IndexerModel(Observable): update_interval: int = 60 * 60 - -def get_version() -> str: - with open(root_dir / "pyproject.toml") as file: - pyproject_toml = file.read() - - match = re.search(r'version = "(.+)"', pyproject_toml) - if match: - version = match.group(1) - else: - raise ValueError("Could not find version in pyproject.toml") - return version - class LoggingModel(Observable): ... @@ -356,6 +344,7 @@ class PostProcessing(Observable): class AppModel(Observable): version: str = get_version() + api_key: str = "" debug: bool = True log: bool = True force_refresh: bool = False @@ -378,3 +367,6 @@ def __init__(self, **data: Any): super().__init__(**data) if existing_version < current_version: self.version = current_version + + if self.api_key == "": + self.api_key = generate_api_key() diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 000000000..24ec207ec --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,30 @@ +from auth import resolve_api_key +from fastapi import Depends, Request +from fastapi.routing import APIRouter +from program.settings.manager import settings_manager +from routers.models.shared import RootResponse +from routers.secure.default import router as default_router +from routers.secure.items import router as items_router +from routers.secure.scrape import router as scrape_router +from routers.secure.settings import router as settings_router +from routers.secure.tmdb import router as tmdb_router +from routers.secure.webhooks import router as webooks_router +from routers.secure.ws import router as ws_router + +API_VERSION = "v1" + +app_router = APIRouter(prefix=f"/api/{API_VERSION}") +@app_router.get("/", operation_id="root") +async def root(_: Request) -> RootResponse: + return { + "message": "Riven is running!", + "version": settings_manager.settings.version, + } + +app_router.include_router(default_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(items_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(scrape_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)]) \ No newline at end of file diff --git a/src/controllers/models/overseerr.py b/src/routers/models/overseerr.py similarity index 100% rename from src/controllers/models/overseerr.py rename to src/routers/models/overseerr.py diff --git a/src/controllers/models/plex.py b/src/routers/models/plex.py similarity index 100% rename from src/controllers/models/plex.py rename to src/routers/models/plex.py diff --git a/src/routers/models/shared.py b/src/routers/models/shared.py new file mode 100644 index 000000000..62b8ccad3 --- /dev/null +++ b/src/routers/models/shared.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class MessageResponse(BaseModel): + message: str + +class RootResponse(MessageResponse): + version: str \ No newline at end of file diff --git a/src/routers/secure/__init__.py b/src/routers/secure/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/controllers/default.py b/src/routers/secure/default.py similarity index 92% rename from src/controllers/default.py rename to src/routers/secure/default.py index 9d349cb6c..143c68b81 100644 --- a/src/controllers/default.py +++ b/src/routers/secure/default.py @@ -1,234 +1,223 @@ -from typing import Literal - -import requests -from controllers.models.shared import MessageResponse -from fastapi import APIRouter, HTTPException, Request -from loguru import logger -from program.content.trakt import TraktContent -from program.db.db import db -from program.media.item import Episode, MediaItem, Movie, Season, Show -from program.media.state import States -from program.settings.manager import settings_manager -from pydantic import BaseModel, Field -from sqlalchemy import func, select -from utils.event_manager import EventUpdate - -router = APIRouter( - responses={404: {"description": "Not found"}}, -) - - -class RootResponse(MessageResponse): - version: str - - -@router.get("/", operation_id="root") -async def root() -> RootResponse: - return { - "message": "Riven is running!", - "version": settings_manager.settings.version, - } - - -@router.get("/health", operation_id="health") -async def health(request: Request) -> MessageResponse: - return { - "message": request.app.program.initialized, - } - - -class RDUser(BaseModel): - id: int - username: str - email: str - points: int = Field(description="User's RD points") - locale: str - avatar: str = Field(description="URL to the user's avatar") - type: Literal["free", "premium"] - premium: int = Field(description="Premium subscription left in seconds") - - -@router.get("/rd", operation_id="rd") -async def get_rd_user() -> RDUser: - api_key = settings_manager.settings.downloaders.real_debrid.api_key - headers = {"Authorization": f"Bearer {api_key}"} - - proxy = ( - settings_manager.settings.downloaders.real_debrid.proxy_url - if settings_manager.settings.downloaders.real_debrid.proxy_enabled - else None - ) - - response = requests.get( - "https://api.real-debrid.com/rest/1.0/user", - headers=headers, - proxies=proxy if proxy else None, - timeout=10, - ) - - if response.status_code != 200: - return {"success": False, "message": response.json()} - - return response.json() - - -@router.get("/torbox", operation_id="torbox") -async def get_torbox_user(): - api_key = settings_manager.settings.downloaders.torbox.api_key - headers = {"Authorization": f"Bearer {api_key}"} - response = requests.get( - "https://api.torbox.app/v1/api/user/me", headers=headers, timeout=10 - ) - return response.json() - - -@router.get("/services", operation_id="services") -async def get_services(request: Request) -> dict[str, bool]: - data = {} - if hasattr(request.app.program, "services"): - for service in request.app.program.all_services.values(): - data[service.key] = service.initialized - if not hasattr(service, "services"): - continue - for sub_service in service.services.values(): - data[sub_service.key] = sub_service.initialized - return data - - -class TraktOAuthInitiateResponse(BaseModel): - auth_url: str - - -@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") -async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: - raise HTTPException(status_code=404, detail="Trakt service not found") - auth_url = trakt.perform_oauth_flow() - return {"auth_url": auth_url} - - -@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: - raise HTTPException(status_code=404, detail="Trakt service not found") - success = trakt.handle_oauth_callback(code) - if success: - return {"message": "OAuth token obtained successfully"} - else: - raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") - - -class StatsResponse(BaseModel): - total_items: int - total_movies: int - total_shows: int - total_seasons: int - total_episodes: int - total_symlinks: int - incomplete_items: int - incomplete_retries: dict[int, int] = Field( - description="Media item log string: number of retries" - ) - states: dict[States, int] - - -@router.get("/stats", operation_id="stats") -async def get_stats(_: Request) -> StatsResponse: - payload = {} - with db.Session() as session: - # Ensure the connection is open for the entire duration of the session - with session.connection().execution_options(stream_results=True) as conn: - movies_symlinks = conn.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() - episodes_symlinks = conn.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() - total_symlinks = movies_symlinks + episodes_symlinks - - total_movies = conn.execute(select(func.count(Movie._id))).scalar_one() - total_shows = conn.execute(select(func.count(Show._id))).scalar_one() - total_seasons = conn.execute(select(func.count(Season._id))).scalar_one() - total_episodes = conn.execute(select(func.count(Episode._id))).scalar_one() - total_items = conn.execute(select(func.count(MediaItem._id))).scalar_one() - - # Use a server-side cursor for batch processing - incomplete_retries = {} - batch_size = 1000 - - result = conn.execute( - select(MediaItem._id, MediaItem.scraped_times) - .where(MediaItem.last_state != States.Completed) - ) - - while True: - batch = result.fetchmany(batch_size) - if not batch: - break - - for media_item_id, scraped_times in batch: - incomplete_retries[media_item_id] = scraped_times - - states = {} - for state in States: - states[state] = conn.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() - - payload["total_items"] = total_items - payload["total_movies"] = total_movies - payload["total_shows"] = total_shows - payload["total_seasons"] = total_seasons - payload["total_episodes"] = total_episodes - payload["total_symlinks"] = total_symlinks - payload["incomplete_items"] = len(incomplete_retries) - payload["incomplete_retries"] = incomplete_retries - payload["states"] = states - - return payload - -class LogsResponse(BaseModel): - logs: str - - -@router.get("/logs", operation_id="logs") -async def get_logs() -> str: - log_file_path = None - for handler in logger._core.handlers.values(): - if ".log" in handler._name: - log_file_path = handler._sink._path - break - - if not log_file_path: - return {"success": False, "message": "Log file handler not found"} - - try: - with open(log_file_path, "r") as log_file: - log_contents = log_file.read() - return {"logs": log_contents} - except Exception as e: - logger.error(f"Failed to read log file: {e}") - raise HTTPException(status_code=500, detail="Failed to read log file") - - -@router.get("/events", operation_id="events") -async def get_events( - request: Request, -) -> dict[str, list[EventUpdate]]: - return request.app.program.em.get_event_updates() - - -@router.get("/mount", operation_id="mount") -async def get_rclone_files() -> dict[str, str]: - """Get all files in the rclone mount.""" - import os - - rclone_dir = settings_manager.settings.symlink.rclone_path - file_map = {} - - def scan_dir(path): - with os.scandir(path) as entries: - for entry in entries: - if entry.is_file(): - file_map[entry.name] = entry.path - elif entry.is_dir(): - scan_dir(entry.path) - - scan_dir(rclone_dir) # dict of `filename: filepath`` - return file_map +from typing import Literal + +import requests +from fastapi import APIRouter, HTTPException, Request +from loguru import logger +from program.content.trakt import TraktContent +from program.db.db import db +from program.media.item import Episode, MediaItem, Movie, Season, Show +from program.media.state import States +from program.settings.manager import settings_manager +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from utils.event_manager import EventUpdate + +from ..models.shared import MessageResponse + +router = APIRouter( + responses={404: {"description": "Not found"}}, +) + + +@router.get("/health", operation_id="health") +async def health(request: Request) -> MessageResponse: + return { + "message": str(request.app.program.initialized), + } + + +class RDUser(BaseModel): + id: int + username: str + email: str + points: int = Field(description="User's RD points") + locale: str + avatar: str = Field(description="URL to the user's avatar") + type: Literal["free", "premium"] + premium: int = Field(description="Premium subscription left in seconds") + + +@router.get("/rd", operation_id="rd") +async def get_rd_user() -> RDUser: + api_key = settings_manager.settings.downloaders.real_debrid.api_key + headers = {"Authorization": f"Bearer {api_key}"} + + proxy = ( + settings_manager.settings.downloaders.real_debrid.proxy_url + if settings_manager.settings.downloaders.real_debrid.proxy_enabled + else None + ) + + response = requests.get( + "https://api.real-debrid.com/rest/1.0/user", + headers=headers, + proxies=proxy if proxy else None, + timeout=10, + ) + + if response.status_code != 200: + return {"success": False, "message": response.json()} + + return response.json() + + +@router.get("/torbox", operation_id="torbox") +async def get_torbox_user(): + api_key = settings_manager.settings.downloaders.torbox.api_key + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get( + "https://api.torbox.app/v1/api/user/me", headers=headers, timeout=10 + ) + return response.json() + + +@router.get("/services", operation_id="services") +async def get_services(request: Request) -> dict[str, bool]: + data = {} + if hasattr(request.app.program, "services"): + for service in request.app.program.all_services.values(): + data[service.key] = service.initialized + if not hasattr(service, "services"): + continue + for sub_service in service.services.values(): + data[sub_service.key] = sub_service.initialized + return data + + +class TraktOAuthInitiateResponse(BaseModel): + auth_url: str + + +@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") +async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: + trakt = request.app.program.services.get(TraktContent) + if trakt is None: + raise HTTPException(status_code=404, detail="Trakt service not found") + auth_url = trakt.perform_oauth_flow() + return {"auth_url": auth_url} + + +@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") +async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: + trakt = request.app.program.services.get(TraktContent) + if trakt is None: + raise HTTPException(status_code=404, detail="Trakt service not found") + success = trakt.handle_oauth_callback(code) + if success: + return {"message": "OAuth token obtained successfully"} + else: + raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") + + +class StatsResponse(BaseModel): + total_items: int + total_movies: int + total_shows: int + total_seasons: int + total_episodes: int + total_symlinks: int + incomplete_items: int + incomplete_retries: dict[int, int] = Field( + description="Media item log string: number of retries" + ) + states: dict[States, int] + + +@router.get("/stats", operation_id="stats") +async def get_stats(_: Request) -> StatsResponse: + payload = {} + with db.Session() as session: + # Ensure the connection is open for the entire duration of the session + with session.connection().execution_options(stream_results=True) as conn: + movies_symlinks = conn.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() + episodes_symlinks = conn.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() + total_symlinks = movies_symlinks + episodes_symlinks + + total_movies = conn.execute(select(func.count(Movie._id))).scalar_one() + total_shows = conn.execute(select(func.count(Show._id))).scalar_one() + total_seasons = conn.execute(select(func.count(Season._id))).scalar_one() + total_episodes = conn.execute(select(func.count(Episode._id))).scalar_one() + total_items = conn.execute(select(func.count(MediaItem._id))).scalar_one() + + # Use a server-side cursor for batch processing + incomplete_retries = {} + batch_size = 1000 + + result = conn.execute( + select(MediaItem._id, MediaItem.scraped_times) + .where(MediaItem.last_state != States.Completed) + ) + + while True: + batch = result.fetchmany(batch_size) + if not batch: + break + + for media_item_id, scraped_times in batch: + incomplete_retries[media_item_id] = scraped_times + + states = {} + for state in States: + states[state] = conn.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() + + payload["total_items"] = total_items + payload["total_movies"] = total_movies + payload["total_shows"] = total_shows + payload["total_seasons"] = total_seasons + payload["total_episodes"] = total_episodes + payload["total_symlinks"] = total_symlinks + payload["incomplete_items"] = len(incomplete_retries) + payload["incomplete_retries"] = incomplete_retries + payload["states"] = states + + return payload + +class LogsResponse(BaseModel): + logs: str + + +@router.get("/logs", operation_id="logs") +async def get_logs() -> str: + log_file_path = None + for handler in logger._core.handlers.values(): + if ".log" in handler._name: + log_file_path = handler._sink._path + break + + if not log_file_path: + return {"success": False, "message": "Log file handler not found"} + + try: + with open(log_file_path, "r") as log_file: + log_contents = log_file.read() + return {"logs": log_contents} + except Exception as e: + logger.error(f"Failed to read log file: {e}") + raise HTTPException(status_code=500, detail="Failed to read log file") + + +@router.get("/events", operation_id="events") +async def get_events( + request: Request, +) -> dict[str, list[EventUpdate]]: + return request.app.program.em.get_event_updates() + + +@router.get("/mount", operation_id="mount") +async def get_rclone_files() -> dict[str, str]: + """Get all files in the rclone mount.""" + import os + + rclone_dir = settings_manager.settings.symlink.rclone_path + file_map = {} + + def scan_dir(path): + with os.scandir(path) as entries: + for entry in entries: + if entry.is_file(): + file_map[entry.name] = entry.path + elif entry.is_dir(): + scan_dir(entry.path) + + scan_dir(rclone_dir) # dict of `filename: filepath`` + return file_map diff --git a/src/controllers/items.py b/src/routers/secure/items.py similarity index 98% rename from src/controllers/items.py rename to src/routers/secure/items.py index c8c548f6e..1ae15a295 100644 --- a/src/controllers/items.py +++ b/src/routers/secure/items.py @@ -3,12 +3,6 @@ from typing import Literal, Optional import Levenshtein -from RTN import Torrent -from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import func, select -from sqlalchemy.exc import NoResultFound - -from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request from program.content import Overseerr from program.db.db import db @@ -20,12 +14,12 @@ get_parent_ids, reset_media_item, ) +from program.downloaders import Downloader, get_needed_media from program.media.item import MediaItem from program.media.state import States -from program.symlink import Symlinker -from program.downloaders import Downloader, get_needed_media from program.media.stream import Stream from program.scrapers.shared import rtn +from program.symlink import Symlinker from program.types import Event from pydantic import BaseModel from RTN import Torrent @@ -33,6 +27,8 @@ from sqlalchemy.exc import NoResultFound from utils.logger import logger +from ..models.shared import MessageResponse + router = APIRouter( prefix="/items", tags=["items"], diff --git a/src/controllers/scrape.py b/src/routers/secure/scrape.py similarity index 99% rename from src/controllers/scrape.py rename to src/routers/secure/scrape.py index e975705bf..0985221b9 100644 --- a/src/controllers/scrape.py +++ b/src/routers/secure/scrape.py @@ -94,7 +94,7 @@ async def scrape( for stream in results.values() ] - except StopIteration as e: + except StopIteration: raise HTTPException(status_code=204, detail="Media item not found") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/controllers/settings.py b/src/routers/secure/settings.py similarity index 95% rename from src/controllers/settings.py rename to src/routers/secure/settings.py index d02375c4f..bfc78a6cd 100644 --- a/src/controllers/settings.py +++ b/src/routers/secure/settings.py @@ -1,129 +1,130 @@ -from copy import copy -from typing import Any, Dict, List - -from controllers.models.shared import MessageResponse -from fastapi import APIRouter, HTTPException -from program.settings.manager import settings_manager -from program.settings.models import AppModel -from pydantic import BaseModel, ValidationError - - -class SetSettings(BaseModel): - key: str - value: Any - - -router = APIRouter( - prefix="/settings", - tags=["settings"], - responses={404: {"description": "Not found"}}, -) - - -@router.get("/schema", operation_id="get_settings_schema") -async def get_settings_schema() -> dict[str, Any]: - """ - Get the JSON schema for the settings. - """ - return settings_manager.settings.model_json_schema() - -@router.get("/load", operation_id="load_settings") -async def load_settings() -> MessageResponse: - settings_manager.load() - return { - "message": "Settings loaded!", - } - -@router.post("/save", operation_id="save_settings") -async def save_settings() -> MessageResponse: - settings_manager.save() - return { - "message": "Settings saved!", - } - - -@router.get("/get/all", operation_id="get_all_settings") -async def get_all_settings() -> AppModel: - return copy(settings_manager.settings) - - -@router.get("/get/{paths}", operation_id="get_settings") -async def get_settings(paths: str) -> dict[str, Any]: - current_settings = settings_manager.settings.model_dump() - data = {} - for path in paths.split(","): - keys = path.split(".") - current_obj = current_settings - - for k in keys: - if k not in current_obj: - return None - current_obj = current_obj[k] - - data[path] = current_obj - return data - - -@router.post("/set/all", operation_id="set_all_settings") -async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: - current_settings = settings_manager.settings.model_dump() - - def update_settings(current_obj, new_obj): - for key, value in new_obj.items(): - if isinstance(value, dict) and key in current_obj: - update_settings(current_obj[key], value) - else: - current_obj[key] = value - - update_settings(current_settings, new_settings) - - # Validate and save the updated settings - try: - updated_settings = settings_manager.settings.model_validate(current_settings) - settings_manager.load(settings_dict=updated_settings.model_dump()) - settings_manager.save() # Ensure the changes are persisted - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - return { - "message": "All settings updated successfully!", - } - -@router.post("/set", operation_id="set_settings") -async def set_settings(settings: List[SetSettings]) -> MessageResponse: - current_settings = settings_manager.settings.model_dump() - - for setting in settings: - keys = setting.key.split(".") - current_obj = current_settings - - # Navigate to the last key's parent object, ensuring all keys exist. - for k in keys[:-1]: - if k not in current_obj: - raise HTTPException( - status_code=400, - detail=f"Path '{'.'.join(keys[:-1])}' does not exist.", - ) - current_obj = current_obj[k] - - # Ensure the final key exists before setting the value. - if keys[-1] in current_obj: - current_obj[keys[-1]] = setting.value - else: - raise HTTPException( - status_code=400, - detail=f"Key '{keys[-1]}' does not exist in path '{'.'.join(keys[:-1])}'.", - ) - - # Validate and apply the updated settings to the AppModel instance - try: - updated_settings = settings_manager.settings.__class__(**current_settings) - settings_manager.load(settings_dict=updated_settings.model_dump()) - settings_manager.save() # Ensure the changes are persisted - except ValidationError as e: - raise HTTPException from e( - status_code=400, - detail=f"Failed to update settings: {str(e)}", - ) - - return {"message": "Settings updated successfully."} +from copy import copy +from typing import Any, Dict, List + +from fastapi import APIRouter, HTTPException +from program.settings.manager import settings_manager +from program.settings.models import AppModel +from pydantic import BaseModel, ValidationError + +from ..models.shared import MessageResponse + + +class SetSettings(BaseModel): + key: str + value: Any + + +router = APIRouter( + prefix="/settings", + tags=["settings"], + responses={404: {"description": "Not found"}}, +) + + +@router.get("/schema", operation_id="get_settings_schema") +async def get_settings_schema() -> dict[str, Any]: + """ + Get the JSON schema for the settings. + """ + return settings_manager.settings.model_json_schema() + +@router.get("/load", operation_id="load_settings") +async def load_settings() -> MessageResponse: + settings_manager.load() + return { + "message": "Settings loaded!", + } + +@router.post("/save", operation_id="save_settings") +async def save_settings() -> MessageResponse: + settings_manager.save() + return { + "message": "Settings saved!", + } + + +@router.get("/get/all", operation_id="get_all_settings") +async def get_all_settings() -> AppModel: + return copy(settings_manager.settings) + + +@router.get("/get/{paths}", operation_id="get_settings") +async def get_settings(paths: str) -> dict[str, Any]: + current_settings = settings_manager.settings.model_dump() + data = {} + for path in paths.split(","): + keys = path.split(".") + current_obj = current_settings + + for k in keys: + if k not in current_obj: + return None + current_obj = current_obj[k] + + data[path] = current_obj + return data + + +@router.post("/set/all", operation_id="set_all_settings") +async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: + current_settings = settings_manager.settings.model_dump() + + def update_settings(current_obj, new_obj): + for key, value in new_obj.items(): + if isinstance(value, dict) and key in current_obj: + update_settings(current_obj[key], value) + else: + current_obj[key] = value + + update_settings(current_settings, new_settings) + + # Validate and save the updated settings + try: + updated_settings = settings_manager.settings.model_validate(current_settings) + settings_manager.load(settings_dict=updated_settings.model_dump()) + settings_manager.save() # Ensure the changes are persisted + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return { + "message": "All settings updated successfully!", + } + +@router.post("/set", operation_id="set_settings") +async def set_settings(settings: List[SetSettings]) -> MessageResponse: + current_settings = settings_manager.settings.model_dump() + + for setting in settings: + keys = setting.key.split(".") + current_obj = current_settings + + # Navigate to the last key's parent object, ensuring all keys exist. + for k in keys[:-1]: + if k not in current_obj: + raise HTTPException( + status_code=400, + detail=f"Path '{'.'.join(keys[:-1])}' does not exist.", + ) + current_obj = current_obj[k] + + # Ensure the final key exists before setting the value. + if keys[-1] in current_obj: + current_obj[keys[-1]] = setting.value + else: + raise HTTPException( + status_code=400, + detail=f"Key '{keys[-1]}' does not exist in path '{'.'.join(keys[:-1])}'.", + ) + + # Validate and apply the updated settings to the AppModel instance + try: + updated_settings = settings_manager.settings.__class__(**current_settings) + settings_manager.load(settings_dict=updated_settings.model_dump()) + settings_manager.save() # Ensure the changes are persisted + except ValidationError as e: + raise HTTPException from e( + status_code=400, + detail=f"Failed to update settings: {str(e)}", + ) + + return {"message": "Settings updated successfully."} diff --git a/src/controllers/tmdb.py b/src/routers/secure/tmdb.py similarity index 100% rename from src/controllers/tmdb.py rename to src/routers/secure/tmdb.py diff --git a/src/controllers/webhooks.py b/src/routers/secure/webhooks.py similarity index 98% rename from src/controllers/webhooks.py rename to src/routers/secure/webhooks.py index 097f6c1a1..3f6dd500c 100644 --- a/src/controllers/webhooks.py +++ b/src/routers/secure/webhooks.py @@ -9,7 +9,7 @@ from requests import RequestException from utils.logger import logger -from .models.overseerr import OverseerrWebhook +from ..models.overseerr import OverseerrWebhook router = APIRouter( prefix="/webhook", diff --git a/src/controllers/ws.py b/src/routers/secure/ws.py similarity index 100% rename from src/controllers/ws.py rename to src/routers/secure/ws.py diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e2f85f0e0..f5faed3f0 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,38 @@ +import os +import re +import secrets +import string from pathlib import Path +from loguru import logger + root_dir = Path(__file__).resolve().parents[2] data_dir_path = root_dir / "data" -alembic_dir = data_dir_path / "alembic" \ No newline at end of file +alembic_dir = data_dir_path / "alembic" + +def get_version() -> str: + with open(root_dir / "pyproject.toml") as file: + pyproject_toml = file.read() + + match = re.search(r'version = "(.+)"', pyproject_toml) + if match: + version = match.group(1) + else: + raise ValueError("Could not find version in pyproject.toml") + return version + +def generate_api_key(): + """Generate a secure API key of the specified length.""" + API_KEY = os.getenv("API_KEY") + if len(API_KEY) != 32: + logger.warning("env.API_KEY is not 32 characters long, generating a new one...") + characters = string.ascii_letters + string.digits + + # Generate the API key + api_key = "".join(secrets.choice(characters) for _ in range(32)) + logger.warning(f"New api key: {api_key}") + else: + api_key = API_KEY + + return api_key \ No newline at end of file