Skip to content

Commit

Permalink
request.openapi_validated does not break non-opeanpi views
Browse files Browse the repository at this point in the history
Accessing `request.openapi_validated` in a route that has
`view_config(openapi=False)` will no longer break the route.

Refs #165
  • Loading branch information
zupo committed May 27, 2022
1 parent 29f0180 commit 4ca7f43
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 17 deletions.
36 changes: 30 additions & 6 deletions pyramid_openapi3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ def includeme(config: Configurator) -> None:

def openapi_validated(request: Request) -> dict:
"""Get validated parameters."""
# Validate request and attach all findings for view to introspect

# we need this here in case someone calls request.openapi_validated on
# a view marked with openapi=False
if not request.environ.get("pyramid_openapi3.enabled"):
raise AttributeError(
"Cannot do openapi request validation on a view marked with openapi=False"
)

validate_request = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_request_validation", True
Expand All @@ -83,8 +90,8 @@ def openapi_validated(request: Request) -> dict:
if route_settings and request.matched_route.name in route_settings:
settings = request.registry.settings[route_settings[request.matched_route.name]]

if validate_request: # pragma: no branch
request.environ["pyramid_openapi3.validate_request"] = True
# if request.environ.get("pyramid_openapi3.validate_request"):
if validate_request:
openapi_request = PyramidOpenAPIRequestFactory.create(request)
validated = settings["request_validator"].validate(openapi_request)
return validated
Expand All @@ -97,7 +104,7 @@ def openapi_validated(request: Request) -> dict:


def openapi_view(view: View, info: ViewDeriverInfo) -> View:
"""View deriver that takes care of request/response validation.
"""View deriver that takes care of request validation.
If `openapi=True` is passed to `@view_config`, this decorator will:
Expand All @@ -109,12 +116,29 @@ def openapi_view(view: View, info: ViewDeriverInfo) -> View:
if info.options.get("openapi"):

def wrapper_view(context: Context, request: Request) -> Response:
validate_request = asbool(

# We need this to be able to raise AttributeError if view code
# accesses request.openapi_validated on a view that is marked
# with openapi=False
request.environ["pyramid_openapi3.enabled"] = True

# If view is marked with openapi=True (i.e. we are in this
# function) and registry settings are not set to disable
# validation, then do request/response validation
request.environ["pyramid_openapi3.validate_request"] = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_request_validation", True
)
)
if validate_request and request.openapi_validated.errors:
request.environ["pyramid_openapi3.validate_response"] = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_response_validation", True
)
)

# request validation can happen already here, but response validation
# needs to happen later in a tween
if request.openapi_validated and request.openapi_validated.errors:
raise RequestValidationError(errors=request.openapi_validated.errors)

# Do the view
Expand Down
18 changes: 8 additions & 10 deletions pyramid_openapi3/tests/test_path_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,8 @@
from webtest.app import TestApp


class _FooResource:
def __init__(self, request: Request) -> None:
self.request = request
self.foo_id = request.openapi_validated.parameters["path"]["foo_id"]


def _foo_view(context: _FooResource, request: Request) -> int:
return context.foo_id
def _foo_view(request: Request) -> int:
return request.openapi_validated.parameters["path"]["foo_id"]


def test_path_parameter_validation() -> None:
Expand Down Expand Up @@ -46,8 +40,12 @@ def test_path_parameter_validation() -> None:
config.include("pyramid_openapi3")
config.pyramid_openapi3_spec(tempdoc.name)
config.pyramid_openapi3_register_routes()
config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource)
config.add_view(_foo_view, route_name="foo_route", renderer="json")
# config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource)
config.add_route("foo_route", "/foo/{foo_id}")
config.add_view(
openapi=True, view=_foo_view, route_name="foo_route", renderer="json"
)

app = config.make_wsgi_app()
test_app = TestApp(app)
resp = test_app.get("/foo/1")
Expand Down
25 changes: 25 additions & 0 deletions pyramid_openapi3/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,31 @@ def test_nonapi_view(self) -> None:
self.assertEqual(start_response.status, "200 OK")
self.assertIn(b"foo", b"".join(response))

def test_nonapi_view_raises_AttributeError(self) -> None:
"""Test non-openapi view that accesses request.openapi_validated."""

def should_raise_error(request):
request.openapi_validated

self._add_view(openapi=False, view_func=should_raise_error)
# run request through router
router = Router(self.config.registry)
environ = {
"wsgi.url_scheme": "http",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8080",
"REQUEST_METHOD": "GET",
"PATH_INFO": "/foo",
}
start_response = DummyStartResponse()
with self.assertRaises(AttributeError) as cm:
router(environ, start_response)

self.assertEqual(
str(cm.exception),
"Cannot do openapi request validation on a view marked with openapi=False",
)

def test_request_validation_disabled(self) -> None:
"""Test View with request validation disabled."""
self.config.registry.settings[
Expand Down
2 changes: 1 addition & 1 deletion pyramid_openapi3/tween.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def excview_tween(request: Request) -> Response:
try:
response = handler(request)
if not request.environ.get("pyramid_openapi3.validate_response"):
# not an openapi view or response validation not requested
return response

# validate response
openapi_request = PyramidOpenAPIRequestFactory.create(request)
openapi_response = PyramidOpenAPIResponseFactory.create(response)
Expand Down

0 comments on commit 4ca7f43

Please sign in to comment.