Skip to content

Commit

Permalink
Drop phase handling
Browse files Browse the repository at this point in the history
  • Loading branch information
dmytrotsko committed Jun 26, 2023
1 parent 13fcfe3 commit 86a1898
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 150 deletions.
11 changes: 5 additions & 6 deletions src/server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._config import SECRET, REVERSE_PROXY_DEPTH
from ._db import engine
from ._exceptions import DatabaseErrorException, EpiDataException
from ._security import current_user, _is_public_route, resolve_auth_token, show_no_api_key_warning, update_key_last_time_used, ERROR_MSG_INVALID_KEY
from ._security import current_user, _is_public_route, resolve_auth_token, update_key_last_time_used, ERROR_MSG_INVALID_KEY


app = Flask("EpiData", static_url_path="")
Expand Down Expand Up @@ -127,11 +127,10 @@ def before_request_execute():
user_id=(user and user.id)
)

if not show_no_api_key_warning():
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)

if request.path.startswith("/lib"):
# files served from 'lib' directory don't need the database, so we can exit this early...
Expand Down
2 changes: 0 additions & 2 deletions src/server/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@
}
NATION_REGION = "nat"

API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get("API_KEY_REQUIRED_STARTING_AT", "2023-06-21"))
TEMPORARY_API_KEY = os.environ.get("TEMPORARY_API_KEY", "TEMP-API-KEY-EXPIRES-2023-06-28")
# password needed for the admin application if not set the admin routes won't be available
ADMIN_PASSWORD = os.environ.get("API_KEY_ADMIN_PASSWORD", "abc")
# secret for the google form to give to the admin/register endpoint
Expand Down
27 changes: 2 additions & 25 deletions src/server/_limiter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from delphi.epidata.server.endpoints.covidcast_utils.dashboard_signals import DashboardSignals
from flask import Response, request, make_response, jsonify
from flask import Response, request
from flask_limiter import Limiter, HEADERS
from redis import Redis
from werkzeug.exceptions import Unauthorized, TooManyRequests
Expand All @@ -8,7 +8,7 @@
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD
from ._exceptions import ValidationFailedException
from ._params import extract_dates, extract_integers, extract_strings
from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES
from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES


def deduct_on_success(response: Response) -> bool:
Expand Down Expand Up @@ -108,38 +108,15 @@ def ratelimit_handler(e):
return TooManyRequests(ERROR_MSG_RATE_LIMIT)


def requests_left():
r = Redis(host=REDIS_HOST, password=REDIS_PASSWORD)
allowed_count, period = RATE_LIMIT.split("/")
try:
remaining_count = int(allowed_count) - int(
r.get(f"LIMITER/{_resolve_tracking_key()}/EpidataLimiter/{allowed_count}/1/{period}")
)
except TypeError:
return 1
return remaining_count


@limiter.request_filter
def _no_rate_limit() -> bool:
if show_no_api_key_warning():
# no rate limit in phase 0
return True
if _is_public_route():
# no rate limit for public routes
return True
if current_user:
# no rate limit if user is registered
return True

if not require_api_key():
# we are in phase 1 or 2
if requests_left() > 0:
# ...and user is below rate limit, we still want to record this query for the rate computation...
return False
# ...otherwise, they have exceeded the limit, but we still want to allow them through
return True

# phase 3 (full api-keys behavior)
multiples = get_multiples_count(request)
if multiples < 0:
Expand Down
90 changes: 11 additions & 79 deletions src/server/_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from io import StringIO
from typing import Any, Dict, Iterable, List, Optional, Union

from flask import Response, jsonify, stream_with_context, request
from flask import Response, jsonify, stream_with_context
from flask.json import dumps
import orjson

from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS
# TODO: remove warnings after once we are past the API_KEY_REQUIRED_STARTING_AT date
from ._security import show_hard_api_key_warning, show_soft_api_key_warning, ROLLOUT_WARNING_RATE_LIMIT, ROLLOUT_WARNING_MULTIPLES, _ROLLOUT_WARNING_AD_FRAGMENT, PHASE_1_2_STOPGAP
from ._common import is_compatibility_mode, log_info_with_request
from ._limiter import requests_left, get_multiples_count
from delphi.epidata.common.logger import get_structured_logger


Expand All @@ -25,15 +22,7 @@ def print_non_standard(format: str, data):
message = "no results"
result = -2
else:
warning = ""
if show_hard_api_key_warning():
if requests_left() == 0:
warning = f"{ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
result = 1
if result == -1 and is_compatibility_mode():
return jsonify(dict(result=result, message=message))
Expand Down Expand Up @@ -126,40 +115,21 @@ class ClassicPrinter(APrinter):
"""

def _begin(self):
if is_compatibility_mode() and not show_hard_api_key_warning():
if is_compatibility_mode():
return "{ "
r = '{ "epidata": ['
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning != "":
return f'{r} "{warning.strip()}",'
return r
return '{ "epidata": ['

def _format_row(self, first: bool, row: Dict):
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
sep = b'"epidata": ['
else:
sep = b"," if not first else b""
return sep + orjson.dumps(row)

def _end(self):
warning = ""
if show_soft_api_key_warning():
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
prefix = "], "
if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning():
if self.count == 0 and is_compatibility_mode():
# no array to end
prefix = ""

Expand Down Expand Up @@ -193,7 +163,7 @@ def _format_row(self, first: bool, row: Dict):
self._tree[group].append(row)
else:
self._tree[group] = [row]
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
return b'"epidata": ['
return None

Expand All @@ -204,10 +174,7 @@ def _end(self):
tree = orjson.dumps(self._tree)
self._tree = dict()
r = super(ClassicTreePrinter, self)._end()
r = tree + r
if show_hard_api_key_warning() and (requests_left() == 0 or get_multiples_count(request) < 0):
r = b", " + r
return r
return tree + r


class CSVPrinter(APrinter):
Expand Down Expand Up @@ -239,17 +206,6 @@ def _format_row(self, first: bool, row: Dict):
columns = list(row.keys())
self._writer = DictWriter(self._stream, columns, lineterminator="\n")
self._writer.writeheader()
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
self._writer.writerow({columns[0]: warning})

self._writer.writerow(row)

# remove the stream content to print just one line at a time
Expand All @@ -270,18 +226,7 @@ class JSONPrinter(APrinter):
"""

def _begin(self):
r = b"["
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
r = b'["' + bytes(warning, "utf-8") + b'",'
return r
return b"["

def _format_row(self, first: bool, row: Dict):
sep = b"," if not first else b""
Expand All @@ -299,19 +244,6 @@ class JSONLPrinter(APrinter):
def make_response(self, gen):
return Response(gen, mimetype=" text/plain; charset=utf8")

def _begin(self):
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
return bytes(warning, "utf-8") + b"\n"
return None

def _format_row(self, first: bool, row: Dict):
# each line is a JSON file with a new line to separate them
return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE)
Expand All @@ -334,4 +266,4 @@ def create_printer(format: str) -> APrinter:
return CSVPrinter()
if format == "jsonl":
return JSONLPrinter()
return ClassicPrinter()
return ClassicPrinter()
38 changes: 0 additions & 38 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,13 @@
from werkzeug.local import LocalProxy

from ._config import (
API_KEY_REQUIRED_STARTING_AT,
REDIS_HOST,
REDIS_PASSWORD,
API_KEY_REGISTRATION_FORM_LINK_LOCAL,
TEMPORARY_API_KEY,
URL_PREFIX,
)
from .admin.models import User, UserRole

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)

# rollout warning messages
ROLLOUT_WARNING_RATE_LIMIT = "This request exceeded the rate limit on anonymous requests, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
ROLLOUT_WARNING_MULTIPLES = "This request exceeded the anonymous limit on selected multiples, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
_ROLLOUT_WARNING_AD_FRAGMENT = "To be exempt from this limit, authenticate your requests with a free API key, now available at {}.".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL)

PHASE_1_2_STOPGAP = (
"A temporary public key `{}` is available for use between now and {} to give you time to register or adapt your requests without this message continuing to break your systems."
).format(TEMPORARY_API_KEY, (API_KEY_REQUIRED_STARTING_AT + timedelta(days=7)))


# steady-state error messages
ERROR_MSG_RATE_LIMIT = "Rate limit exceeded for anonymous queries. To remove this limit, register a free API key at {}".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL)
Expand All @@ -54,30 +40,6 @@ def resolve_auth_token() -> Optional[str]:
return None


def show_no_api_key_warning() -> bool:
# aka "phase 0"
n = date.today()
return not current_user and n < API_KEY_SOFT_WARNING


def show_soft_api_key_warning() -> bool:
# aka "phase 1"
n = date.today()
return not current_user and API_KEY_SOFT_WARNING <= n < API_KEY_HARD_WARNING


def show_hard_api_key_warning() -> bool:
# aka "phase 2"
n = date.today()
return not current_user and API_KEY_HARD_WARNING <= n < API_KEY_REQUIRED_STARTING_AT


def require_api_key() -> bool:
# aka "phase 3"
n = date.today()
return API_KEY_REQUIRED_STARTING_AT <= n


def _get_current_user():
if "user" not in g:
api_key = resolve_auth_token()
Expand Down

0 comments on commit 86a1898

Please sign in to comment.