From b3cfb0103600bd3c7643902a396439216d35424f Mon Sep 17 00:00:00 2001 From: maxwrimlinger <37249275+maxwrimlinger@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:58:10 -0500 Subject: [PATCH] Fix awaiting of async view function. --- .../parameter_validation.py | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/flask_parameter_validation/parameter_validation.py b/flask_parameter_validation/parameter_validation.py index 2f56502..20f72c7 100644 --- a/flask_parameter_validation/parameter_validation.py +++ b/flask_parameter_validation/parameter_validation.py @@ -3,7 +3,7 @@ import inspect import re from inspect import signature -from flask import request +from flask import request, Response from werkzeug.datastructures import ImmutableMultiDict from werkzeug.exceptions import BadRequest from .exceptions import (InvalidParameterTypeError, MissingInputError, @@ -40,8 +40,13 @@ def __call__(self, f): } fn_list[fsig] = fdocs - @functools.wraps(f) - async def nested_func(**kwargs): + def nested_func_helper(**kwargs): + """ + Validates the inputs of a Flask route or returns an error. Returns + are wrapped in a dictionary with a flag to let nested_func() know + if it should unpack the resulting dictionary of inputs as kwargs, + or just return the error message. + """ # Step 1 - Get expected input details as dict expected_inputs = signature(f).parameters @@ -54,7 +59,7 @@ async def nested_func(**kwargs): try: json_input = request.json except BadRequest: - return {"error": "Could not parse JSON."}, 400 + return {"error": ({"error": "Could not parse JSON."}, 400), "validated": False} # Step 3 - Extract list of parameters expected to be lists (otherwise all values are converted to lists) expected_list_params = [] @@ -79,18 +84,32 @@ async def nested_func(**kwargs): try: new_input = self.validate(expected, request_inputs) except (MissingInputError, ValidationError) as e: - return {"error": str(e)}, 400 + return {"error": ({"error": str(e)}, 400), "validated": False} else: try: new_input = self.validate(expected, request_inputs) except Exception as e: - return self.custom_error_handler(e) + return {"error": self.custom_error_handler(e), "validated": False} validated_inputs[expected.name] = new_input - if asyncio.iscoroutinefunction(f): - return await f(**validated_inputs) - else: - return f(**validated_inputs) + return {"inputs": validated_inputs, "validated": True} + + if asyncio.iscoroutinefunction(f): + # If the view function is async, return and await a coroutine + @functools.wraps(f) + async def nested_func(**kwargs): + validated_inputs = nested_func_helper(**kwargs) + if validated_inputs["validated"]: + return await f(**validated_inputs["inputs"]) + return validated_inputs["error"] + else: + # If the view function is not async, return a function + @functools.wraps(f) + def nested_func(**kwargs): + validated_inputs = nested_func_helper(**kwargs) + if validated_inputs["validated"]: + return f(**validated_inputs["inputs"]) + return validated_inputs["error"] nested_func.__name__ = f.__name__ return nested_func