Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AK-107 Drop phase handling #1204

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._config import SECRET, REVERSE_PROXY_DEPTH
from ._db import engine
from ._exceptions import DatabaseErrorException, EpiDataException
from ._security import current_user, _is_public_route, resolve_auth_token, show_no_api_key_warning, update_key_last_time_used, ERROR_MSG_INVALID_KEY
from ._security import current_user, _is_public_route, resolve_auth_token, update_key_last_time_used, ERROR_MSG_INVALID_KEY


app = Flask("EpiData", static_url_path="")
Expand Down Expand Up @@ -127,11 +127,10 @@ def before_request_execute():
user_id=(user and user.id)
)

if not show_no_api_key_warning():
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)
if not _is_public_route() and api_key and not user:
# if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception:
get_structured_logger("server_api").info("bad api key used", api_key=api_key)
raise Unauthorized(ERROR_MSG_INVALID_KEY)

if request.path.startswith("/lib"):
# files served from 'lib' directory don't need the database, so we can exit this early...
Expand Down
4 changes: 0 additions & 4 deletions src/server/_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
from datetime import date
from enum import Enum

from dotenv import load_dotenv

Expand Down Expand Up @@ -84,9 +83,6 @@
"cen9": ["AK", "CA", "HI", "OR", "WA"],
}
NATION_REGION = "nat"

API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get("API_KEY_REQUIRED_STARTING_AT", "2023-06-21"))
TEMPORARY_API_KEY = os.environ.get("TEMPORARY_API_KEY", "TEMP-API-KEY-EXPIRES-2023-06-28")
# password needed for the admin application if not set the admin routes won't be available
ADMIN_PASSWORD = os.environ.get("API_KEY_ADMIN_PASSWORD", "abc")
# secret for the google form to give to the admin/register endpoint
Expand Down
6 changes: 5 additions & 1 deletion src/server/_db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from sqlalchemy import create_engine, MetaData
from sqlalchemy.engine import Engine
import redis
from sqlalchemy.orm import sessionmaker

from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_DATABASE_URI_PRIMARY, SQLALCHEMY_ENGINE_OPTIONS
from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_DATABASE_URI_PRIMARY, SQLALCHEMY_ENGINE_OPTIONS, REDIS_HOST, REDIS_PASSWORD


# _db.py exists so that we dont have a circular dependency:
Expand All @@ -20,4 +21,7 @@

Session = sessionmaker(bind=user_engine)

redis_conn_pool = redis.ConnectionPool(host=REDIS_HOST, password=REDIS_PASSWORD, max_connections=200)
redis_conn = redis.Redis(connection_pool=redis_conn_pool)


32 changes: 5 additions & 27 deletions src/server/_limiter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from delphi.epidata.server.endpoints.covidcast_utils.dashboard_signals import DashboardSignals
from flask import Response, request, make_response, jsonify
from flask import Response, request
from flask_limiter import Limiter, HEADERS
from redis import Redis
from werkzeug.exceptions import Unauthorized, TooManyRequests

from ._common import app, get_real_ip_addr
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD
from ._db import redis_conn_pool
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL
from ._exceptions import ValidationFailedException
from ._params import extract_dates, extract_integers, extract_strings
from ._security import _is_public_route, current_user, require_api_key, show_no_api_key_warning, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES
from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES


def deduct_on_success(response: Response) -> bool:
Expand Down Expand Up @@ -91,6 +91,7 @@ def _resolve_tracking_key() -> str:
_resolve_tracking_key,
app=app,
storage_uri=RATELIMIT_STORAGE_URL,
storage_options={"connection_pool": redis_conn_pool},
request_identifier=lambda: "EpidataLimiter",
headers_enabled=True,
header_name_mapping={
Expand All @@ -108,38 +109,15 @@ def ratelimit_handler(e):
return TooManyRequests(ERROR_MSG_RATE_LIMIT)


def requests_left():
r = Redis(host=REDIS_HOST, password=REDIS_PASSWORD)
allowed_count, period = RATE_LIMIT.split("/")
try:
remaining_count = int(allowed_count) - int(
r.get(f"LIMITER/{_resolve_tracking_key()}/EpidataLimiter/{allowed_count}/1/{period}")
)
except TypeError:
return 1
return remaining_count


@limiter.request_filter
def _no_rate_limit() -> bool:
if show_no_api_key_warning():
# no rate limit in phase 0
return True
if _is_public_route():
# no rate limit for public routes
return True
if current_user:
# no rate limit if user is registered
return True

if not require_api_key():
# we are in phase 1 or 2
if requests_left() > 0:
# ...and user is below rate limit, we still want to record this query for the rate computation...
return False
# ...otherwise, they have exceeded the limit, but we still want to allow them through
return True

# phase 3 (full api-keys behavior)
multiples = get_multiples_count(request)
if multiples < 0:
Expand Down
88 changes: 10 additions & 78 deletions src/server/_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from io import StringIO
from typing import Any, Dict, Iterable, List, Optional, Union

from flask import Response, jsonify, stream_with_context, request
from flask import Response, jsonify, stream_with_context
from flask.json import dumps
import orjson

from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS
# TODO: remove warnings after once we are past the API_KEY_REQUIRED_STARTING_AT date
from ._security import show_hard_api_key_warning, show_soft_api_key_warning, ROLLOUT_WARNING_RATE_LIMIT, ROLLOUT_WARNING_MULTIPLES, _ROLLOUT_WARNING_AD_FRAGMENT, PHASE_1_2_STOPGAP
from ._common import is_compatibility_mode, log_info_with_request
from ._limiter import requests_left, get_multiples_count
from delphi.epidata.common.logger import get_structured_logger


Expand All @@ -25,15 +22,7 @@ def print_non_standard(format: str, data):
message = "no results"
result = -2
else:
warning = ""
if show_hard_api_key_warning():
if requests_left() == 0:
warning = f"{ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
result = 1
if result == -1 and is_compatibility_mode():
return jsonify(dict(result=result, message=message))
Expand Down Expand Up @@ -126,40 +115,21 @@ class ClassicPrinter(APrinter):
"""

def _begin(self):
if is_compatibility_mode() and not show_hard_api_key_warning():
if is_compatibility_mode():
return "{ "
r = '{ "epidata": ['
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning != "":
return f'{r} "{warning.strip()}",'
return r
return '{ "epidata": ['

def _format_row(self, first: bool, row: Dict):
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
sep = b'"epidata": ['
else:
sep = b"," if not first else b""
return sep + orjson.dumps(row)

def _end(self):
warning = ""
if show_soft_api_key_warning():
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
message = warning.strip() or "success"
message = "success"
prefix = "], "
if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning():
if self.count == 0 and is_compatibility_mode():
# no array to end
prefix = ""

Expand Down Expand Up @@ -193,7 +163,7 @@ def _format_row(self, first: bool, row: Dict):
self._tree[group].append(row)
else:
self._tree[group] = [row]
if first and is_compatibility_mode() and not show_hard_api_key_warning():
if first and is_compatibility_mode():
return b'"epidata": ['
return None

Expand All @@ -204,10 +174,7 @@ def _end(self):
tree = orjson.dumps(self._tree)
self._tree = dict()
r = super(ClassicTreePrinter, self)._end()
r = tree + r
if show_hard_api_key_warning() and (requests_left() == 0 or get_multiples_count(request) < 0):
r = b", " + r
return r
return tree + r


class CSVPrinter(APrinter):
Expand Down Expand Up @@ -239,17 +206,6 @@ def _format_row(self, first: bool, row: Dict):
columns = list(row.keys())
self._writer = DictWriter(self._stream, columns, lineterminator="\n")
self._writer.writeheader()
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
self._writer.writerow({columns[0]: warning})

self._writer.writerow(row)

# remove the stream content to print just one line at a time
Expand All @@ -270,18 +226,7 @@ class JSONPrinter(APrinter):
"""

def _begin(self):
r = b"["
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
r = b'["' + bytes(warning, "utf-8") + b'",'
return r
return b"["

def _format_row(self, first: bool, row: Dict):
sep = b"," if not first else b""
Expand All @@ -299,19 +244,6 @@ class JSONLPrinter(APrinter):
def make_response(self, gen):
return Response(gen, mimetype=" text/plain; charset=utf8")

def _begin(self):
if show_hard_api_key_warning():
warning = ""
if requests_left() == 0:
warning = f"{warning} {ROLLOUT_WARNING_RATE_LIMIT}"
if get_multiples_count(request) < 0:
warning = f"{warning} {ROLLOUT_WARNING_MULTIPLES}"
if requests_left() == 0 or get_multiples_count(request) < 0:
warning = f"{warning} {_ROLLOUT_WARNING_AD_FRAGMENT} {PHASE_1_2_STOPGAP}"
if warning.strip() != "":
return bytes(warning, "utf-8") + b"\n"
return None

def _format_row(self, first: bool, row: Dict):
# each line is a JSON file with a new line to separate them
return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE)
Expand Down
46 changes: 2 additions & 44 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,17 @@
from functools import wraps
from typing import Optional, cast

import redis
from delphi.epidata.common.logger import get_structured_logger
from flask import g, request
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy

from ._config import (
API_KEY_REQUIRED_STARTING_AT,
REDIS_HOST,
REDIS_PASSWORD,
API_KEY_REGISTRATION_FORM_LINK_LOCAL,
TEMPORARY_API_KEY,
URL_PREFIX,
)
from .admin.models import User, UserRole

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)

# rollout warning messages
ROLLOUT_WARNING_RATE_LIMIT = "This request exceeded the rate limit on anonymous requests, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
ROLLOUT_WARNING_MULTIPLES = "This request exceeded the anonymous limit on selected multiples, which will be enforced starting {}.".format(API_KEY_REQUIRED_STARTING_AT)
_ROLLOUT_WARNING_AD_FRAGMENT = "To be exempt from this limit, authenticate your requests with a free API key, now available at {}.".format(API_KEY_REGISTRATION_FORM_LINK_LOCAL)

PHASE_1_2_STOPGAP = (
"A temporary public key `{}` is available for use between now and {} to give you time to register or adapt your requests without this message continuing to break your systems."
).format(TEMPORARY_API_KEY, (API_KEY_REQUIRED_STARTING_AT + timedelta(days=7)))
from ._db import redis_conn


# steady-state error messages
Expand All @@ -53,31 +37,6 @@ def resolve_auth_token() -> Optional[str]:
return auth_header[len("Bearer ") :]
return None


def show_no_api_key_warning() -> bool:
# aka "phase 0"
n = date.today()
return not current_user and n < API_KEY_SOFT_WARNING


def show_soft_api_key_warning() -> bool:
# aka "phase 1"
n = date.today()
return not current_user and API_KEY_SOFT_WARNING <= n < API_KEY_HARD_WARNING


def show_hard_api_key_warning() -> bool:
# aka "phase 2"
n = date.today()
return not current_user and API_KEY_HARD_WARNING <= n < API_KEY_REQUIRED_STARTING_AT


def require_api_key() -> bool:
# aka "phase 3"
n = date.today()
return API_KEY_REQUIRED_STARTING_AT <= n


def _get_current_user():
if "user" not in g:
api_key = resolve_auth_token()
Expand Down Expand Up @@ -129,5 +88,4 @@ def update_key_last_time_used(user):
return
if user:
# update last usage for this user's api key to "now()"
r = redis.Redis(host=REDIS_HOST, password=REDIS_PASSWORD)
r.set(f"LAST_USED/{user.api_key}", datetime.strftime(datetime.now(), "%Y-%m-%d"))
redis_conn.set(f"LAST_USED/{user.api_key}", datetime.strftime(datetime.now(), "%Y-%m-%d"))