diff --git a/src/server/_common.py b/src/server/_common.py index 56d4c38ec..8257d12f8 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -11,7 +11,7 @@ from ._config import SECRET, REVERSE_PROXY_DEPTH from ._db import engine from ._exceptions import DatabaseErrorException, EpiDataException -from ._security import current_user, _is_public_route, resolve_auth_token, show_no_api_key_warning, update_key_last_time_used, ERROR_MSG_INVALID_KEY +from ._security import current_user, _is_public_route, resolve_auth_token, update_key_last_time_used, ERROR_MSG_INVALID_KEY app = Flask("EpiData", static_url_path="") @@ -127,11 +127,10 @@ def before_request_execute(): user_id=(user and user.id) ) - if not show_no_api_key_warning(): - if not _is_public_route() and api_key and not user: - # if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception: - get_structured_logger("server_api").info("bad api key used", api_key=api_key) - raise Unauthorized(ERROR_MSG_INVALID_KEY) + if not _is_public_route() and api_key and not user: + # if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception: + get_structured_logger("server_api").info("bad api key used", api_key=api_key) + raise Unauthorized(ERROR_MSG_INVALID_KEY) if request.path.startswith("/lib"): # files served from 'lib' directory don't need the database, so we can exit this early... diff --git a/src/server/_config.py b/src/server/_config.py index 168512a3d..3aad57e51 100644 --- a/src/server/_config.py +++ b/src/server/_config.py @@ -1,7 +1,6 @@ import json import os from datetime import date -from enum import Enum from dotenv import load_dotenv @@ -84,9 +83,6 @@ "cen9": ["AK", "CA", "HI", "OR", "WA"], } NATION_REGION = "nat" - -API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get("API_KEY_REQUIRED_STARTING_AT", "2023-06-21")) -TEMPORARY_API_KEY = os.environ.get("TEMPORARY_API_KEY", "TEMP-API-KEY-EXPIRES-2023-06-28") # password needed for the admin application if not set the admin routes won't be available ADMIN_PASSWORD = os.environ.get("API_KEY_ADMIN_PASSWORD", "abc") # secret for the google form to give to the admin/register endpoint diff --git a/src/server/_db.py b/src/server/_db.py index 53e632cdf..7c3d8927e 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -1,8 +1,9 @@ from sqlalchemy import create_engine, MetaData from sqlalchemy.engine import Engine +import redis from sqlalchemy.orm import sessionmaker -from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_DATABASE_URI_PRIMARY, SQLALCHEMY_ENGINE_OPTIONS +from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_DATABASE_URI_PRIMARY, SQLALCHEMY_ENGINE_OPTIONS, REDIS_HOST, REDIS_PASSWORD # _db.py exists so that we dont have a circular dependency: @@ -20,4 +21,7 @@ Session = sessionmaker(bind=user_engine) +redis_conn_pool = redis.ConnectionPool(host=REDIS_HOST, password=REDIS_PASSWORD, max_connections=200) +redis_conn = redis.Redis(connection_pool=redis_conn_pool) + diff --git a/src/server/_limiter.py b/src/server/_limiter.py index 4bf72e05b..8bf6817ca 100644 --- a/src/server/_limiter.py +++ b/src/server/_limiter.py @@ -1,14 +1,14 @@ from delphi.epidata.server.endpoints.covidcast_utils.dashboard_signals import DashboardSignals -from flask import Response, request, make_response, jsonify +from flask import Response, request from flask_limiter import Limiter, HEADERS -from redis import Redis from werkzeug.exceptions import Unauthorized, TooManyRequests from ._common import app, get_real_ip_addr -from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD +from ._db import redis_conn_pool +from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL from ._exceptions import ValidationFailedException from ._params import extract_dates, extract_integers, extract_strings -from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES +from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES def deduct_on_success(response: Response) -> bool: @@ -91,6 +91,7 @@ def _resolve_tracking_key() -> str: _resolve_tracking_key, app=app, storage_uri=RATELIMIT_STORAGE_URL, + storage_options={"connection_pool": redis_conn_pool}, request_identifier=lambda: "EpidataLimiter", headers_enabled=True, header_name_mapping={ @@ -108,23 +109,8 @@ def ratelimit_handler(e): return TooManyRequests(ERROR_MSG_RATE_LIMIT) -def requests_left(): - r = Redis(host=REDIS_HOST, password=REDIS_PASSWORD) - allowed_count, period = RATE_LIMIT.split("/") - try: - remaining_count = int(allowed_count) - int( - r.get(f"LIMITER/{_resolve_tracking_key()}/EpidataLimiter/{allowed_count}/1/{period}") - ) - except TypeError: - return 1 - return remaining_count - - @limiter.request_filter def _no_rate_limit() -> bool: - if show_no_api_key_warning(): - # no rate limit in phase 0 - return True if _is_public_route(): # no rate limit for public routes return True @@ -132,14 +118,6 @@ def _no_rate_limit() -> bool: # no rate limit if user is registered return True - if not require_api_key(): - # we are in phase 1 or 2 - if requests_left() > 0: - # ...and user is below rate limit, we still want to record this query for the rate computation... - return False - # ...otherwise, they have exceeded the limit, but we still want to allow them through - return True - # phase 3 (full api-keys behavior) multiples = get_multiples_count(request) if multiples < 0: diff --git a/src/server/_printer.py b/src/server/_printer.py index 6e32d7d43..7bbf86d32 100644 --- a/src/server/_printer.py +++ b/src/server/_printer.py @@ -2,15 +2,12 @@ from io import StringIO from typing import Any, Dict, Iterable, List, Optional, Union -from flask import Response, jsonify, stream_with_context, request +from flask import Response, jsonify, stream_with_context from flask.json import dumps import orjson from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS -# TODO: remove warnings after once we are past the API_KEY_REQUIRED_STARTING_AT date -from ._security import show_hard_api_key_warning, show_soft_api_key_warning, ROLLOUT_WARNING_RATE_LIMIT, ROLLOUT_WARNING_MULTIPLES, _ROLLOUT_WARNING_AD_FRAGMENT, PHASE_1_2_STOPGAP from ._common import is_compatibility_mode, log_info_with_request -from ._limiter import requests_left, get_multiples_count from delphi.epidata.common.logger import get_structured_logger @@ -25,15 +22,7 @@ def print_non_standard(format: str, data): message = "no results" result = -2 else: - warning = "" - if show_hard_api_key_warning(): - if requests_left() == 0: - warning = f"{ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - message = warning.strip() or "success" + message = "success" result = 1 if result == -1 and is_compatibility_mode(): return jsonify(dict(result=result, message=message)) @@ -126,40 +115,21 @@ class ClassicPrinter(APrinter): """ def _begin(self): - if is_compatibility_mode() and not show_hard_api_key_warning(): + if is_compatibility_mode(): return "{ " - r = '{ "epidata": [' - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning != "": - return f'{r} "{warning.strip()}",' - return r + return '{ "epidata": [' def _format_row(self, first: bool, row: Dict): - if first and is_compatibility_mode() and not show_hard_api_key_warning(): + if first and is_compatibility_mode(): sep = b'"epidata": [' else: sep = b"," if not first else b"" return sep + orjson.dumps(row) def _end(self): - warning = "" - if show_soft_api_key_warning(): - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - message = warning.strip() or "success" + message = "success" prefix = "], " - if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning(): + if self.count == 0 and is_compatibility_mode(): # no array to end prefix = "" @@ -193,7 +163,7 @@ def _format_row(self, first: bool, row: Dict): self._tree[group].append(row) else: self._tree[group] = [row] - if first and is_compatibility_mode() and not show_hard_api_key_warning(): + if first and is_compatibility_mode(): return b'"epidata": [' return None @@ -204,10 +174,7 @@ def _end(self): tree = orjson.dumps(self._tree) self._tree = dict() r = super(ClassicTreePrinter, self)._end() - r = tree + r - if show_hard_api_key_warning() and (requests_left() == 0 or get_multiples_count(request) < 0): - r = b", " + r - return r + return tree + r class CSVPrinter(APrinter): @@ -239,17 +206,6 @@ def _format_row(self, first: bool, row: Dict): columns = list(row.keys()) self._writer = DictWriter(self._stream, columns, lineterminator="\n") self._writer.writeheader() - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - self._writer.writerow({columns[0]: warning}) - self._writer.writerow(row) # remove the stream content to print just one line at a time @@ -270,18 +226,7 @@ class JSONPrinter(APrinter): """ def _begin(self): - r = b"[" - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - r = b'["' + bytes(warning, "utf-8") + b'",' - return r + return b"[" def _format_row(self, first: bool, row: Dict): sep = b"," if not first else b"" @@ -299,19 +244,6 @@ class JSONLPrinter(APrinter): def make_response(self, gen): return Response(gen, mimetype=" text/plain; charset=utf8") - def _begin(self): - if show_hard_api_key_warning(): - warning = "" - if requests_left() == 0: - warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}" - if get_multiples_count(request) < 0: - warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}" - if requests_left() == 0 or get_multiples_count(request) < 0: - warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}" - if warning.strip() != "": - return bytes(warning, "utf-8") + b"\n" - return None - def _format_row(self, first: bool, row: Dict): # each line is a JSON file with a new line to separate them return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE) diff --git a/src/server/_security.py b/src/server/_security.py index 761d088c3..545e443c4 100644 --- a/src/server/_security.py +++ b/src/server/_security.py @@ -2,33 +2,17 @@ from functools import wraps from typing import Optional, cast -import redis from delphi.epidata.common.logger import get_structured_logger from flask import g, request from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy from ._config import ( - API_KEY_REQUIRED_STARTING_AT, - REDIS_HOST, - REDIS_PASSWORD, API_KEY_REGISTRATION_FORM_LINK_LOCAL, - TEMPORARY_API_KEY, URL_PREFIX, ) from .admin.models import User, UserRole - -API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14) -API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14) - -# rollout warning messages -ROLLOUT_WARNING_RATE_LIMIT = "This request exceeded the rate limit on anonymous requests, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT) -ROLLOUT_WARNING_MULTIPLES = "This request exceeded the anonymous limit on selected multiples, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT) -_ROLLOUT_WARNING_AD_FRAGMENT = "To be exempt from this limit, authenticate your requests with a free API key, now available at {}.".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL) - -PHASE_1_2_STOPGAP = ( - "A temporary public key `{}` is available for use between now and {} to give you time to register or adapt your requests without this message continuing to break your systems." -).format(TEMPORARY_API_KEY, (API_KEY_REQUIRED_STARTING_AT + timedelta(days=7))) +from ._db import redis_conn # steady-state error messages @@ -53,31 +37,6 @@ def resolve_auth_token() -> Optional[str]: return auth_header[len("Bearer ") :] return None - -def show_no_api_key_warning() -> bool: - # aka "phase 0" - n = date.today() - return not current_user and n < API_KEY_SOFT_WARNING - - -def show_soft_api_key_warning() -> bool: - # aka "phase 1" - n = date.today() - return not current_user and API_KEY_SOFT_WARNING <= n < API_KEY_HARD_WARNING - - -def show_hard_api_key_warning() -> bool: - # aka "phase 2" - n = date.today() - return not current_user and API_KEY_HARD_WARNING <= n < API_KEY_REQUIRED_STARTING_AT - - -def require_api_key() -> bool: - # aka "phase 3" - n = date.today() - return API_KEY_REQUIRED_STARTING_AT <= n - - def _get_current_user(): if "user" not in g: api_key = resolve_auth_token() @@ -129,5 +88,4 @@ def update_key_last_time_used(user): return if user: # update last usage for this user's api key to "now()" - r = redis.Redis(host=REDIS_HOST, password=REDIS_PASSWORD) - r.set(f"LAST_USED/{user.api_key}", datetime.strftime(datetime.now(), "%Y-%m-%d")) + redis_conn.set(f"LAST_USED/{user.api_key}", datetime.strftime(datetime.now(), "%Y-%m-%d"))