Skip to content

Commit

Permalink
Refactor: replace Flask-Restful with own, simple implementation (#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
psrok1 authored Feb 27, 2024
1 parent b198501 commit 3b11349
Show file tree
Hide file tree
Showing 28 changed files with 168 additions and 222 deletions.
11 changes: 7 additions & 4 deletions mwdb/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class HashConverter(BaseConverter):
app.register_blueprint(static_blueprint)


@app.before_request
@api.blueprint.before_request
def assign_request_id():
g.request_id = token_hex(16)
g.request_start_time = datetime.utcnow()
Expand All @@ -152,7 +152,7 @@ def assign_request_id():
)


@app.after_request
@api.blueprint.after_request
def log_request(response):
if hasattr(g, "request_start_time"):
response_time = datetime.utcnow() - g.request_start_time
Expand Down Expand Up @@ -186,7 +186,7 @@ def log_request(response):
return response


@app.before_request
@api.blueprint.before_request
def require_auth():
if request.method == "OPTIONS":
return
Expand Down Expand Up @@ -221,7 +221,7 @@ def require_auth():
raise Forbidden("User has been disabled.")


@app.before_request
@api.blueprint.before_request
def apply_rate_limit():
apply_rate_limit_for_request()

Expand Down Expand Up @@ -412,3 +412,6 @@ def apply_rate_limit():
plugin_context = PluginAppContext()
with app.app_context():
load_plugins(plugin_context)

# Register blueprint
api.register()
91 changes: 0 additions & 91 deletions mwdb/core/apispec_utils.py

This file was deleted.

6 changes: 2 additions & 4 deletions mwdb/core/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flask import Blueprint, Flask
from flask import Flask
from werkzeug.middleware.proxy_fix import ProxyFix

from mwdb.core.config import app_config
Expand All @@ -7,9 +7,7 @@

app = Flask(__name__, static_folder=None)
app.config["MAX_CONTENT_LENGTH"] = app_config.mwdb.max_upload_size
api_blueprint = Blueprint("api", __name__, url_prefix="/api")
api = Service(app, api_blueprint)
app.register_blueprint(api_blueprint)
api = Service(app)

if app_config.mwdb.use_x_forwarded_for:
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1)
5 changes: 4 additions & 1 deletion mwdb/core/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def apply_rate_limit_for_request() -> bool:
):
return False
# Split blueprint name and resource name from endpoint
_, resource_name = request.endpoint.split(".", 2)
if request.endpoint:
_, resource_name = request.endpoint.split(".", 2)
else:
resource_name = "None"
method = request.method.lower()
user = g.auth_user.login if g.auth_user is not None else request.remote_addr
# Limit keys from most specific to the least specific
Expand Down
209 changes: 131 additions & 78 deletions mwdb/core/service.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,150 @@
import re
import sys
import textwrap
from functools import partial

from apispec import APISpec
from apispec import APISpec, yaml_utils
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_restful import Api
from flask import Blueprint, Flask, jsonify, request
from flask.typing import ResponseReturnValue
from flask.views import MethodView
from sqlalchemy.exc import OperationalError
from werkzeug.exceptions import HTTPException, ServiceUnavailable
from werkzeug.exceptions import (
HTTPException,
InternalServerError,
MethodNotAllowed,
ServiceUnavailable,
)
from werkzeug.wrappers import Response

from mwdb.version import app_version

from . import log
from .apispec_utils import ApispecFlaskRestful


class Service(Api):
def __init__(self, flask_app, *args, **kwargs):
self.spec = self._create_spec()
self.flask_app = flask_app
super().__init__(*args, **kwargs)

def _init_app(self, app):
# I want to log exceptions on my own
def dont_log(*_, **__):
pass

app.log_exception = dont_log
if (
isinstance(app.handle_exception, partial)
and app.handle_exception.func is self.error_router
):
# Prevent double-initialization
return
super()._init_app(app)

def _create_spec(self):
spec = APISpec(
title="MWDB",
from .log import getLogger

logger = getLogger()


def flaskpath2openapi(path: str) -> str:
"""Convert a Flask URL rule to an OpenAPI-compliant path.
Got from https://github.com/marshmallow-code/apispec-webframeworks/
:param str path: Flask path template.
"""
# from flask-restplus
re_url = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")
return re_url.sub(r"{\1}", path)


class Resource(MethodView):
init_every_request = False

def dispatch_request(self, *args, **kwargs):
method = request.method.lower()
if not hasattr(self, method):
raise MethodNotAllowed(
valid_methods=self.methods,
description="Method is not allowed for this endpoint",
)
response = getattr(self, method)(*args, **kwargs)
if isinstance(response, Response):
return response
return jsonify(response)


class Service:
description = textwrap.dedent(
"""
MWDB API documentation.
If you want to automate things, we recommend using
<a href="https://github.com/CERT-Polska/mwdblib">
mwdblib library
</a>
"""
)
servers = [
{
"url": "{scheme}://{host}",
"description": "MWDB API endpoint",
"variables": {
"scheme": {"enum": ["http", "https"], "default": "https"},
"host": {"default": "mwdb.cert.pl"},
},
}
]

def __init__(self, app: Flask) -> None:
self.app = app
self.blueprint = Blueprint("api", __name__, url_prefix="/api")
self.spec = APISpec(
title="MWDB Core",
version=app_version,
openapi_version="3.0.2",
plugins=[ApispecFlaskRestful(), MarshmallowPlugin()],
plugins=[MarshmallowPlugin()],
info={"description": self.description},
servers=self.servers,
)

spec.components.security_scheme(
self.spec.components.security_scheme(
"bearerAuth", {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
)
spec.options["info"] = {
"description": textwrap.dedent(
"""
MWDB API documentation.

If you want to automate things, we recommend using
<a href="http://github.com/CERT-Polska/mwdblib">mwdblib library</a>"""
def _make_error_response(self, exc: HTTPException) -> ResponseReturnValue:
return jsonify({"message": exc.description}), exc.code

def error_handler(self, exc: Exception) -> ResponseReturnValue:
if isinstance(exc, HTTPException):
return self._make_error_response(exc)
elif isinstance(exc, OperationalError):
return self._make_error_response(
ServiceUnavailable(
description="Request canceled due to statement timeout"
)
)
}
spec.options["servers"] = [
{
"url": "{scheme}://{host}",
"description": "MWDB API endpoint",
"variables": {
"scheme": {"enum": ["http", "https"], "default": "https"},
"host": {"default": "mwdb.cert.pl"},
},
}
]
return spec

def error_router(self, original_handler, e):
logger = log.getLogger()
if isinstance(e, HTTPException):
logger.error(str(e))
elif isinstance(e, OperationalError):
logger.error(str(e))
raise ServiceUnavailable("Request canceled due to statement timeout")
else:
logger.exception("Unhandled exception occurred")

# Handle all exceptions using handle_error, not only for owned routes
try:
return self.handle_error(e)
except Exception:
logger.exception("Exception from handle_error occurred")
pass
# If something went wrong - fallback to original behavior
return super().error_router(original_handler, e)

def add_resource(self, resource, *urls, undocumented=False, **kwargs):
super().add_resource(resource, *urls, **kwargs)
# Unknown exception, return ISE 500
logger.exception("Internal server error", exc_info=sys.exc_info())
return self._make_error_response(
InternalServerError(description="Internal server error")
)

def add_resource(
self, resource: Resource, *urls: str, undocumented: bool = False
) -> None:
view = resource.as_view(resource.__name__)
endpoint = view.__name__.lower()
for url in urls:
self.blueprint.add_url_rule(rule=url, endpoint=endpoint, view_func=view)
if not undocumented:
self.spec.path(resource=resource, api=self, app=self.flask_app)
resource_doc = resource.__doc__ or ""
operations = yaml_utils.load_operations_from_docstring(resource_doc)
for method in resource.methods:
method_name = method.lower()
method_doc = getattr(resource, method_name).__doc__
if method_doc:
operations[method_name] = yaml_utils.load_yaml_from_docstring(
method_doc
)
for url in urls:
prefixed_url = self.blueprint.url_prefix + "/" + url.lstrip("/")
self.spec.path(
path=flaskpath2openapi(prefixed_url), operations=operations
)

def register(self):
"""
Registers service and its blueprint to the app.
This must be done after adding all resources.
"""
# This handler is intentionally set on app and not blueprint
# to catch routing errors as well. The side effect is that
# it will return jsonified error messages for static endpoints
# but static files should be handled by separate server anyway...
self.app.register_error_handler(Exception, self.error_handler)
self.app.register_blueprint(self.blueprint)

def relative_url_for(self, resource, **values):
path = self.url_for(resource, **values)
# TODO: Remove this along with legacy download endpoint
endpoint = self.blueprint.name + "." + resource.__name__.lower()
path = self.app.url_for(endpoint, **values)
return path[len(self.blueprint.url_prefix) :]

def endpoint_for(self, resource):
return f"{self.blueprint.name}.{resource}"
Loading

0 comments on commit 3b11349

Please sign in to comment.