Skip to content

Commit

Permalink
Fix awaiting of async view function.
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrimlinger committed Feb 27, 2024
1 parent 42769af commit b3cfb01
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions flask_parameter_validation/parameter_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import re
from inspect import signature
from flask import request
from flask import request, Response
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from .exceptions import (InvalidParameterTypeError, MissingInputError,
Expand Down Expand Up @@ -40,8 +40,13 @@ def __call__(self, f):
}
fn_list[fsig] = fdocs

@functools.wraps(f)
async def nested_func(**kwargs):
def nested_func_helper(**kwargs):
"""
Validates the inputs of a Flask route or returns an error. Returns
are wrapped in a dictionary with a flag to let nested_func() know
if it should unpack the resulting dictionary of inputs as kwargs,
or just return the error message.
"""
# Step 1 - Get expected input details as dict
expected_inputs = signature(f).parameters

Expand All @@ -54,7 +59,7 @@ async def nested_func(**kwargs):
try:
json_input = request.json
except BadRequest:
return {"error": "Could not parse JSON."}, 400
return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False}

# Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists)
expected_list_params = []
Expand All @@ -79,18 +84,32 @@ async def nested_func(**kwargs):
try:
new_input = self.validate(expected, request_inputs)
except (MissingInputError, ValidationError) as e:
return {"error": str(e)}, 400
return {"error": ({"error": str(e)}, 400), "validated": False}
else:
try:
new_input = self.validate(expected, request_inputs)
except Exception as e:
return self.custom_error_handler(e)
return {"error": self.custom_error_handler(e), "validated": False}
validated_inputs[expected.name] = new_input

if asyncio.iscoroutinefunction(f):
return await f(**validated_inputs)
else:
return f(**validated_inputs)
return {"inputs": validated_inputs, "validated": True}

if asyncio.iscoroutinefunction(f):
# If the view function is async, return and await a coroutine
@functools.wraps(f)
async def nested_func(**kwargs):
validated_inputs = nested_func_helper(**kwargs)
if validated_inputs["validated"]:
return await f(**validated_inputs["inputs"])
return validated_inputs["error"]
else:
# If the view function is not async, return a function
@functools.wraps(f)
def nested_func(**kwargs):
validated_inputs = nested_func_helper(**kwargs)
if validated_inputs["validated"]:
return f(**validated_inputs["inputs"])
return validated_inputs["error"]

nested_func.__name__ = f.__name__
return nested_func
Expand Down

0 comments on commit b3cfb01

Please sign in to comment.