diff --git a/pyramid_openapi3/__init__.py b/pyramid_openapi3/__init__.py index c1a7c5a..ee2dde4 100644 --- a/pyramid_openapi3/__init__.py +++ b/pyramid_openapi3/__init__.py @@ -74,12 +74,24 @@ def openapi_validated(request: Request) -> dict: "Cannot do openapi request validation on a view marked with openapi=False" ) + validate_request = asbool( + request.registry.settings.get( + "pyramid_openapi3.enable_request_validation", True + ) + ) + validate_response = asbool( + request.registry.settings.get( + "pyramid_openapi3.enable_response_validation", True + ) + ) + request.environ["pyramid_openapi3.validate_response"] = validate_response gsettings = settings = request.registry.settings["pyramid_openapi3"] route_settings = gsettings.get("routes") if route_settings and request.matched_route.name in route_settings: settings = request.registry.settings[route_settings[request.matched_route.name]] - if request.environ.get("pyramid_openapi3.validate_request"): + # if request.environ.get("pyramid_openapi3.validate_request"): + if validate_request: openapi_request = PyramidOpenAPIRequestFactory.create(request) validated = settings["request_validator"].validate(openapi_request) return validated @@ -104,8 +116,7 @@ def openapi_view(view: View, info: ViewDeriverInfo) -> View: if info.options.get("openapi"): def wrapper_view(context: Context, request: Request) -> Response: - __import__("pdb").set_trace() - # __import__("pdb").set_trace() + # We need this to be able to raise AttributeError if view code # accesses request.openapi_validated on a view that is marked # with openapi=False diff --git a/pyramid_openapi3/tests/test_path_parameters.py b/pyramid_openapi3/tests/test_path_parameters.py index 7f2b257..5ca8beb 100644 --- a/pyramid_openapi3/tests/test_path_parameters.py +++ b/pyramid_openapi3/tests/test_path_parameters.py @@ -6,14 +6,8 @@ from webtest.app import TestApp -class _FooResource: - def __init__(self, request: Request) -> None: - self.request = request - self.foo_id = request.openapi_validated.parameters["path"]["foo_id"] - - -def _foo_view(context: _FooResource, request: Request) -> int: - return context.foo_id +def _foo_view(request: Request) -> int: + return request.openapi_validated.parameters["path"]["foo_id"] def test_path_parameter_validation() -> None: @@ -47,7 +41,7 @@ def test_path_parameter_validation() -> None: config.pyramid_openapi3_spec(tempdoc.name) config.pyramid_openapi3_register_routes() # config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource) - config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource) + config.add_route("foo_route", "/foo/{foo_id}") config.add_view( openapi=True, view=_foo_view, route_name="foo_route", renderer="json" )