diff --git a/admin/admin.py b/admin/admin.py index 1d9368f6..f534076d 100644 --- a/admin/admin.py +++ b/admin/admin.py @@ -30,6 +30,7 @@ from werkzeug.utils import secure_filename import api +from cache_config import clear_api_cache_keys, clear_view_cache_keys from util import get_theme_directories, length_check # Blueprint configuration @@ -175,6 +176,8 @@ def save_settings(settings: Dict[str, Any], flash_msg: str) -> Response: # Load the theme template if the current theme is changed set_theme_loader(app, remove_cache=True) + clear_view_cache_keys(all_users=True) + clear_api_cache_keys("admin_save_settings") flash(flash_msg, 'success') return redirect(url_for("admin_bp.index")) diff --git a/api.py b/api.py index b16463dc..ef830a0c 100644 --- a/api.py +++ b/api.py @@ -4,6 +4,7 @@ __license__ = 'GPLv3, see LICENSE' import base64 +import hashlib import json import re import sys @@ -15,6 +16,7 @@ from flask import current_app as app from irods import message, rule +from cache_config import cache, clear_api_cache_keys, get_api_cache_timeout, make_key from errors import InvalidAPIError, UnauthorizedAPIAccessError from util import log_error @@ -23,51 +25,116 @@ @api_bp.route('/', methods=['POST']) def _call(fn: str) -> Response: + """Handle API calls to specified function. + + :param fn: The name of the API function to call + + :returns: JSON response containing the result of the API call + + :raises UnauthorizedAPIAccessError: If the user is not authenticated + :raises InvalidAPIError: If the function name is invalid + """ if not authenticated(): raise UnauthorizedAPIAccessError - if not re.match("^([a-z_]+)$", fn): + if not re.match(r"^[a-z_]+$", fn): raise InvalidAPIError - data: Dict[str, Any] = {} - if 'data' in request.form: - data = json.loads(request.form['data']) + data = json.loads(request.form.get('data', '{}')) + result = call(fn, data) + return jsonify(result), get_response_code(result) - result: Dict[str, Any] = call(fn, data) - code: int = 200 - if result['status'] == 'error_internal': - code = 500 - elif result['status'] != 'ok': - code = 400 +def get_response_code(result: Dict[str, Any]) -> int: + """Determine the HTTP response code based on the result status. - response = jsonify(result) - response.status_code = code - return response + :param result: The result dictionary from the API call + + :returns: HTTP status code + """ + if result['status'] == 'error_internal': + return 500 + return 400 if result['status'] != 'ok' else 200 def call(fn: str, data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Call the specified API function with the provided data. + + :param fn: The name of the API function to call + :param data: Optional dictionary of data to pass to the function + + :returns: The result of the API call as a dictionary + """ + if app.config.get('LOG_API_CALL_DURATION', False): + begintime = timer() + + if data is None: + data = {} + + params = json.dumps(data) + encoded_params = hashlib.shake_256(params.encode('utf-8')).hexdigest(20) + + # Clear API cache keys if the API function called impacts keys. + clear_api_cache_keys(fn) + + timeout = get_api_cache_timeout(fn) + cached_result = None + if timeout > 0: + cached_result = cache.get(make_key(f"{fn}-{encoded_params}")) + + # Execute rule if there is no cached result. + if cached_result is None: + result = execute_rule(fn, params) + + # Cache result if a timeout is specified for this API. + if timeout > 0: + cache.set(make_key(f"{fn}-{encoded_params}"), result, timeout=timeout) + else: + result = cached_result + + if app.config.get('LOG_API_CALL_DURATION', False): + endtime = timer() + callduration = round((endtime - begintime) * 1000) + log_message = f"DEBUG: {callduration:4d}ms api_{fn} {params}" + if cached_result is not None: + log_message += " (from cache)" + print(log_message, file=sys.stderr) + + return json.loads(result) + + +def execute_rule(fn: str, params: str) -> str: + """Execute the specified iRODS rule with the given parameters. + + :param fn: The name of the API function to execute + :param params: The parameters to pass to the rule + + :returns: The output of the rule execution as a string. + """ def bytesbuf_to_str(s: message.BinBytesBuf) -> str: + """Convert a BinBytesBuf to a string, handling null termination.""" s = s.buf[:s.buflen] i = s.find(b'\x00') return s if i < 0 else s[:i] def escape_quotes(s: str) -> str: + """Escape quotes in a string for safe inclusion in rules.""" return s.replace('\\', '\\\\').replace('"', '\\"') def break_strings(N: int, m: int) -> int: + """Calculate the number of segments needed to break a string.""" return (N - 1) // m + 1 def nrep_string_expr(s: str, m: int = 64) -> str: - return '++\n'.join(f'"{escape_quotes(s[i * m:i * m + m])}"' for i in range(break_strings(len(s), m) + 1)) + """Break up the string literal to work around limits for both parameter strings + and literal string constants in the iRODS core code. - if app.config.get('LOG_API_CALL_DURATION', False): - begintime = timer() - - if data is None: - data = {} + :param s: The string to be broken + :param m: The maximum length of each segment - params = json.dumps(data) + :returns: A string formatted for iRODS rule input + """ + return '++\n'.join(f'"{escape_quotes(s[i * m:i * m + m])}"' for i in range(break_strings(len(s), m) + 1)) # Compress params and encode as base64 to reduce size (max rule length in iRODS is 20KB) compressed_params = zlib.compress(params.encode()) @@ -91,35 +158,36 @@ def nrep_string_expr(s: str, m: int = 64) -> str: g.irods.cleanup() x = x.execute(session_cleanup=False) - x = bytesbuf_to_str(x._values['MsParam_PI'][0]._values['inOutStruct']._values['stdoutBuf']) - - result = x.decode() - - if app.config.get('LOG_API_CALL_DURATION', False): - endtime = timer() - callduration = round((endtime - begintime) * 1000) - print(f"DEBUG: {callduration:4d}ms api_{fn} {params}", file=sys.stderr) - - return json.loads(result) + return bytesbuf_to_str(x._values['MsParam_PI'][0]._values['inOutStruct']._values['stdoutBuf']) def authenticated() -> bool: + """Check if the user is authenticated. + + :returns: True if the user is authenticated, False otherwise + """ return g.get('user') is not None and g.get('irods') is not None @api_bp.errorhandler(Exception) def api_error_handler(error: Exception) -> Response: + """Handle exceptions raised during API calls. + + :param error: The exception that was raised + + :returns: A JSON response containing the error details and HTTP status code + """ log_error(f'API Error: {error}', True) status = "internal_error" status_info = "Something went wrong" data: Dict[str, Any] = {} - code = 500 + code = 500 # Default to internal server error. - if type(error) is InvalidAPIError: + # Determine specific error types and set appropriate response details. + if isinstance(error, InvalidAPIError): code = 400 status_info = "Bad API request" - - if type(error) is UnauthorizedAPIAccessError: + elif isinstance(error, UnauthorizedAPIAccessError): code = 401 status_info = "Not authorized to use the API" @@ -128,4 +196,5 @@ def api_error_handler(error: Exception) -> Response: "status": status, "status_info": status_info, "data": data - }), code + } + ), code diff --git a/app.py b/app.py index 6b2b12e9..61f8dfa5 100644 --- a/app.py +++ b/app.py @@ -14,6 +14,7 @@ from admin.admin import admin_bp, set_theme_loader from api import api_bp +from cache_config import cache from datarequest.datarequest import datarequest_bp from deposit.deposit import deposit_bp from fileviewer.fileviewer import fileviewer_bp @@ -28,6 +29,7 @@ from util import get_validated_static_path, log_error from vault.vault import vault_bp + app = Flask(__name__, static_folder='assets') app.json.sort_keys = False @@ -123,6 +125,9 @@ def load_admin_setting() -> Dict[str, Any]: # Start Flask-Session Session(app) +# Initialize the cache. +cache.init_app(app) + # Start monitoring thread for extracting tech support information # Monitor signal file can be set to empty to completely disable monitor thread monitor_enabled: bool = app.config.get("MONITOR_SIGNAL_FILE", "/var/www/yoda/show-tech.sig") != "" diff --git a/cache_config.py b/cache_config.py new file mode 100644 index 00000000..f93a17a4 --- /dev/null +++ b/cache_config.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +__copyright__ = 'Copyright (c) 2024-2025, Utrecht University' +__license__ = 'GPLv3, see LICENSE' + +import hashlib +import json +from concurrent.futures import ThreadPoolExecutor +from functools import wraps +from typing import Callable, List, Optional + +from flask import current_app as app, g, request, session +from flask_caching import Cache + +import api +from util import log_error + +# Create a global ThreadPoolExecutor. +executor = ThreadPoolExecutor(max_workers=2) + +# Configuration for caching. +config = { + "CACHE_TYPE": "RedisCache", + "CACHE_KEY_PREFIX": "yoda_portal_cache_", + "CACHE_DEFAULT_TIMEOUT": 60 * 60, + "CACHE_REDIS_HOST": "localhost", + "CACHE_REDIS_PORT": 6379, + "CACHE_REDIS_DB": 1, + "CACHE_OPTIONS": {"socket_timeout": 2, "retry_on_timeout": True}, +} + +# API cache timeouts configuration. +API_CACHE_TIMEOUTS = { + "group_data": {"timeout": 3600}, + "notifications_load": {"timeout": 120}, + "resource_browse_group_data": {"timeout": 3600}, + "resource_category_stats": {"timeout": 3600}, + "resource_monthly_category_stats": {"timeout": 3600}, + "schema_get_schemas": {"timeout": 3600}, + "settings_load": {"timeout": 3600}, + "token_load": {"timeout": 3600}, + "vault_get_publication_terms": {"timeout": 3600}, + "vault_preservable_formats_lists": {"timeout": 3600}, +} + +# API cache default parameters configuration. +API_CACHE_PARAMS = { + "resource_browse_group_data": { + "default_params": {"offset": 0, "limit": 200, "sort_order": "asc", "sort_on": "name", "search_groups": ""} + }, +} + +# API cache clear configuration. +API_CACHE_CLEAR = { + "admin_save_settings": {"type": "global", "endpoints": ["vault_get_publication_terms", + "vault_preservable_formats_lists"]}, + "group_create": {"type": "global", "endpoints": ["group_data"]}, + "group_update": {"type": "global", "endpoints": ["group_data"]}, + "group_delete": {"type": "global", "endpoints": ["group_data"]}, + "group_user_add": {"type": "global", "endpoints": ["group_data"]}, + "group_user_update_role": {"type": "global", "endpoints": ["group_data"]}, + "group_remove_user_from_group": {"type": "global", "endpoints": ["group_data"]}, + "notifications_dismiss": {"type": "user", "endpoints": ["notifications_load"]}, + "notifications_dismiss_all": {"type": "user", "endpoints": ["notifications_load"]}, + "settings_save": {"type": "user", "endpoints": ["settings_load"]}, + "token_generate": {"type": "user", "endpoints": ["token_load"]}, + "token_delete": {"type": "user", "endpoints": ["token_load"]}, +} + + +def authenticated() -> bool: + """Check if the user is authenticated. + + :returns: True if the user is authenticated, False otherwise + """ + return g.get("user") is not None and g.get("irods") is not None + + +def get_user_identifier() -> str: + """Get user identifier generated from username, return unredacted username in development environments. + + :returns: User identifier + """ + if authenticated(): + user = g.get('user') + is_development = app.config.get("YODA_ENVIRONMENT") == "development" + return user if is_development else hashlib.shake_256(user.encode("utf-8")).hexdigest(20) + else: + return "unauthenticated" + + +def make_key(api_key: Optional[str] = None) -> str: + """Generate a cache key based on the request and authentication status. + + :param api_key: Optional custom key identifying API endpoint. If None, defaults to request endpoint and method + + :returns: A string representing the cache key + """ + if api_key is None: + key = f"view-{session.sid}-{request.endpoint}_{request.method}" + else: + key = f"api-{api_key}" + + user_identifier = get_user_identifier() + + return f"{user_identifier}-{key}" + + +def get_api_cache_functions() -> list: + """Return a list of function names from the API_CACHE_TIMEOUTS configuration.""" + return list(API_CACHE_TIMEOUTS.keys()) + + +def get_api_cache_timeout(fn: str) -> int: + """Retrieve the cache timeout for a specific API function. + + :param fn: The name of the API function + + :returns: The cache timeout in seconds, or 0 if not found + """ + return API_CACHE_TIMEOUTS.get(fn, {"timeout": 0})["timeout"] + + +def filter_cache_keys(substring: str) -> List[str]: + """Retrieve cache keys that contain a specified substring. + + :param substring: The substring to search for in cache key names + + :returns: A list of matching cache keys + """ + prefix = cache.cache.key_prefix + keys = cache.cache._write_client.keys(f"*{substring}*") + return [bs.decode("utf-8")[len(prefix):] for bs in keys] + + +def clear_view_cache_keys(all_users: bool = False) -> None: + """Clear view cache keys associated with the current user or all users. + + :param all_users: If True, clear view cache keys for all users + """ + if all_users: + filter_key = "-view-" + else: + user_identifier = get_user_identifier() + filter_key = f"{user_identifier}-view-" + + # Get the keys to delete. + keys_to_delete = list(filter_cache_keys(filter_key)) + # Attempt to delete the keys from the cache. + try: + cache.delete_many(*keys_to_delete) + except Exception: + log_error(f"Error deleting view cache keys: {keys_to_delete}", True) + + +def clear_api_cache_keys(fn: str) -> None: + """Clear API cache keys associated with a specified API function. + + :param fn: The name of the API function + """ + if fn in API_CACHE_CLEAR: + for endpoint in API_CACHE_CLEAR[fn]["endpoints"]: + # Determine the key to filter based on the type. + if API_CACHE_CLEAR[fn]["type"] == "user": + filter_key = make_key(endpoint) + else: + filter_key = endpoint + + # Get the keys to delete. + keys_to_delete = list(filter_cache_keys(filter_key)) + # Attempt to delete the keys from the cache. + try: + cache.delete_many(*keys_to_delete) + except Exception: + log_error(f"Error deleting API cache keys: {keys_to_delete}", True) + + +def populate_api_cache(fn: str, user: str, irods: str, session_id: str) -> None: + """Function to prepopulate the API cache for specified function.""" + # Set and session context. + g.user = user + g.irods = irods + session.sid = session_id + + if authenticated(): + data = {} + if fn in API_CACHE_PARAMS: + data = API_CACHE_PARAMS[fn]["default_params"] + + params = json.dumps(data) + encoded_params = hashlib.shake_256(params.encode("utf-8")).hexdigest(20) + cache_key = make_key(f"{fn}-{encoded_params}") + if cache.get(cache_key) is None: + try: + api.call(fn, data) + except Exception as e: + log_error(f"Error prepopulating cache {fn}: {e}", True) + + +def cache_view() -> Callable: + """Custom decorator to conditionally apply caching to views.""" + def decorator(f: Callable) -> Callable: + @wraps(f) + def wrapped(*args: str, **kwargs: int) -> Callable: + if cache: + return cache.cached(make_cache_key=make_key)(f)(*args, **kwargs) + else: + return f(*args, **kwargs) + return wrapped + return decorator + + +# Call the function +cache = Cache(config=config) diff --git a/deposit/deposit.py b/deposit/deposit.py index f0867533..7966da9b 100644 --- a/deposit/deposit.py +++ b/deposit/deposit.py @@ -24,6 +24,7 @@ import api import connman +from cache_config import cache_view deposit_bp = Blueprint('deposit_bp', __name__, template_folder='templates', @@ -41,6 +42,7 @@ @deposit_bp.route('/') @deposit_bp.route('/browse') +@cache_view() def index() -> Response: """Deposit overview""" return render_template('deposit/overview.html', @@ -166,6 +168,7 @@ def submit() -> Response: @deposit_bp.route('/thank-you') +@cache_view() def thankyou() -> Response: """Step 4: Thank you""" return render_template('deposit/thank-you.html') diff --git a/general/general.py b/general/general.py index 5242c76e..4db2ad53 100644 --- a/general/general.py +++ b/general/general.py @@ -6,6 +6,7 @@ from flask import Blueprint, redirect, render_template, request, Response, session, url_for from flask_wtf.csrf import CSRFError +from cache_config import cache_view from util import log_error general_bp = Blueprint('general_bp', __name__, @@ -15,6 +16,7 @@ @general_bp.route('/') +@cache_view() def index() -> Response: return render_template('index.html') diff --git a/general/templates/general/base.html b/general/templates/general/base.html index bf919047..b935f7ab 100644 --- a/general/templates/general/base.html +++ b/general/templates/general/base.html @@ -29,6 +29,7 @@ Yoda.user = { username: '{{ g.user }}', }; + Yoda.notifications() {% endif %} {% if g.settings %} Yoda.settings = {{ g.settings | tojson}} @@ -38,6 +39,7 @@ {% else %} Yoda.set_color_mode_auto() {% endif %} + {% block scripts %}{% endblock scripts %} {% endblock head %} diff --git a/general/templates/general/user.html b/general/templates/general/user.html index 865de333..ccdd9b8b 100644 --- a/general/templates/general/user.html +++ b/general/templates/general/user.html @@ -1,11 +1,11 @@ {% if g.user %}