Skip to content

Commit

Permalink
request.openapi_validated does not break non-opeanpi views
Browse files Browse the repository at this point in the history
Accessing `request.openapi_validated` in a route that has
`view_config(openapi=False)` will no longer break the route.

Also a fix to make response validation still work even if request validation
is turned off.

Finally, test_request_validation_disabled was a moot test: it would pass even
if validation was enabled, because the request was valid. Made sure both this
test and test_response_validation_disabled first verify the request/response
is invalid, then disable validation and test again.

Refs #165
  • Loading branch information
zupo committed Nov 28, 2022
1 parent 4d50bb7 commit aa924e4
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 35 deletions.
42 changes: 27 additions & 15 deletions pyramid_openapi3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,20 @@ def includeme(config: Configurator) -> None:

def openapi_validated(request: Request) -> dict:
"""Get validated parameters."""
# Validate request and attach all findings for view to introspect
validate_request = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_request_validation", True
)
)
validate_response = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_response_validation", True

# We need this here in case someone calls request.openapi_validated on
# a view marked with openapi=False
if not request.environ.get("pyramid_openapi3.enabled"):
raise AttributeError(
"Cannot do openapi request validation on a view marked with openapi=False"
)
)
request.environ["pyramid_openapi3.validate_response"] = validate_response

gsettings = settings = request.registry.settings["pyramid_openapi3"]
route_settings = gsettings.get("routes")
if route_settings and request.matched_route.name in route_settings:
settings = request.registry.settings[route_settings[request.matched_route.name]]

if validate_request: # pragma: no branch
request.environ["pyramid_openapi3.validate_request"] = True
if request.environ.get("pyramid_openapi3.validate_request"):
openapi_request = PyramidOpenAPIRequestFactory.create(request)
validated = settings["request_validator"].validate(openapi_request)
return validated
Expand All @@ -109,12 +104,29 @@ def openapi_view(view: View, info: ViewDeriverInfo) -> View:
if info.options.get("openapi"):

def wrapper_view(context: Context, request: Request) -> Response:
validate_request = asbool(

# We need this to be able to raise AttributeError if view code
# accesses request.openapi_validated on a view that is marked
# with openapi=False
request.environ["pyramid_openapi3.enabled"] = True

# If view is marked with openapi=True (i.e. we are in this
# function) and registry settings are not set to disable
# validation, then do request/response validation
request.environ["pyramid_openapi3.validate_request"] = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_request_validation", True
)
)
if validate_request and request.openapi_validated.errors:
request.environ["pyramid_openapi3.validate_response"] = asbool(
request.registry.settings.get(
"pyramid_openapi3.enable_response_validation", True
)
)

# Request validation can happen already here, but response validation
# needs to happen later in a tween
if request.openapi_validated and request.openapi_validated.errors:
raise RequestValidationError(errors=request.openapi_validated.errors)

# Do the view
Expand Down
16 changes: 6 additions & 10 deletions pyramid_openapi3/tests/test_path_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,8 @@
from webtest.app import TestApp


class _FooResource:
def __init__(self, request: Request) -> None:
self.request = request
self.foo_id = request.openapi_validated.parameters["path"]["foo_id"]


def _foo_view(context: _FooResource, request: Request) -> int:
return context.foo_id
def _foo_view(request: Request) -> int:
return request.openapi_validated.parameters["path"]["foo_id"]


def test_path_parameter_validation() -> None:
Expand Down Expand Up @@ -46,8 +40,10 @@ def test_path_parameter_validation() -> None:
config.include("pyramid_openapi3")
config.pyramid_openapi3_spec(tempdoc.name)
config.pyramid_openapi3_register_routes()
config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource)
config.add_view(_foo_view, route_name="foo_route", renderer="json")
config.add_route("foo_route", "/foo/{foo_id}")
config.add_view(
_foo_view, route_name="foo_route", renderer="json", openapi=True
)
app = config.make_wsgi_app()
test_app = TestApp(app)
resp = test_app.get("/foo/1")
Expand Down
77 changes: 68 additions & 9 deletions pyramid_openapi3/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,36 @@ def test_nonapi_view(self) -> None:
self.assertEqual(start_response.status, "200 OK")
self.assertIn(b"foo", b"".join(response))

def test_nonapi_view_raises_AttributeError(self) -> None:
"""Test non-openapi view that accesses request.openapi_validated."""

def should_raise_error(request: Request) -> None:
request.openapi_validated

self._add_view(openapi=False, view_func=should_raise_error)
# run request through router
router = Router(self.config.registry)
environ = {
"wsgi.url_scheme": "http",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8080",
"REQUEST_METHOD": "GET",
"PATH_INFO": "/foo",
}
start_response = DummyStartResponse()
with self.assertRaises(AttributeError) as cm:
router(environ, start_response)

self.assertEqual(
str(cm.exception),
"Cannot do openapi request validation on a view marked with openapi=False",
)

def test_request_validation_disabled(self) -> None:
"""Test View with request validation disabled."""
self.config.registry.settings[
"pyramid_openapi3.enable_request_validation"
] = False
self._add_view(lambda *arg: {"test": "correct"})
# run request through router

# by default validation is enabled
router = Router(self.config.registry)
environ = {
"wsgi.url_scheme": "http",
Expand All @@ -282,21 +305,58 @@ def test_request_validation_disabled(self) -> None:
"REQUEST_METHOD": "GET",
"PATH_INFO": "/foo",
"HTTP_ACCEPT": "application/json",
"QUERY_STRING": "bar=1",
"QUERY_STRING": "bad=parameter",
}
start_response = DummyStartResponse()
response = router(environ, start_response)
self.assertEqual(start_response.status, "400 Bad Request")

# now let's disable it
self.config.registry.settings[
"pyramid_openapi3.enable_request_validation"
] = False
start_response = DummyStartResponse()
response = router(environ, start_response)
self.assertEqual(start_response.status, "200 OK")
self.assertEqual(json.loads(response[0]), {"test": "correct"})

def test_response_validation_disabled(self) -> None:
"""Test View with response validation disabled."""
self._add_view(lambda *arg: "not-valid")

# by default validation is enabled
router = Router(self.config.registry)
environ = {
"wsgi.url_scheme": "http",
"SERVER_NAME": "localhost",
"SERVER_PORT": "8080",
"REQUEST_METHOD": "GET",
"PATH_INFO": "/foo",
"HTTP_ACCEPT": "application/json",
"QUERY_STRING": "bar=1",
}
start_response = DummyStartResponse()
response = router(environ, start_response)
self.assertEqual(start_response.status, "500 Internal Server Error")

# now let's disable it
self.config.registry.settings[
"pyramid_openapi3.enable_response_validation"
] = False
start_response = DummyStartResponse()
response = router(environ, start_response)
self.assertEqual(start_response.status, "200 OK")
self.assertEqual(json.loads(response[0]), "not-valid")

def test_request_validation_disabled_response_validation_enabled(self) -> None:
"""Test response validation still works if request validation is disabled."""
self._add_view(lambda *arg: "not-valid")

# run request through router
self.config.registry.settings[
"pyramid_openapi3.enable_request_validation"
] = False

# by default validation is enabled
router = Router(self.config.registry)
environ = {
"wsgi.url_scheme": "http",
Expand All @@ -308,9 +368,8 @@ def test_response_validation_disabled(self) -> None:
"QUERY_STRING": "bar=1",
}
start_response = DummyStartResponse()
response = router(environ, start_response)
self.assertEqual(start_response.status, "200 OK")
self.assertIn(b'"not-valid"', response)
router(environ, start_response)
self.assertEqual(start_response.status, "500 Internal Server Error")


class TestImproperAPISpecValidation(RequestValidationBase):
Expand Down
2 changes: 1 addition & 1 deletion pyramid_openapi3/tween.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def excview_tween(request: Request) -> Response:
try:
response = handler(request)
if not request.environ.get("pyramid_openapi3.validate_response"):
# not an openapi view or response validation not requested
return response

# validate response
openapi_request = PyramidOpenAPIRequestFactory.create(request)
openapi_response = PyramidOpenAPIResponseFactory.create(response)
Expand Down

0 comments on commit aa924e4

Please sign in to comment.