Skip to content

Commit

Permalink
feat: adding support for Django Ninja test client
Browse files Browse the repository at this point in the history
  • Loading branch information
maticardenas committed Apr 4, 2024
1 parent 02743d1 commit 72c78a1
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 100 deletions.
54 changes: 50 additions & 4 deletions openapi_tester/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from __future__ import annotations

import http
from typing import TYPE_CHECKING

# pylint: disable=import-error
from ninja import NinjaAPI, Router
from ninja.testing import TestClient
from rest_framework.test import APIClient

from .response_handler_factory import ResponseHandlerFactory
from .schema_tester import SchemaTester

if TYPE_CHECKING:
Expand All @@ -25,17 +30,58 @@ def __init__(
super().__init__(*args, **kwargs)
self.schema_tester = schema_tester or self._schema_tester_factory()

def request(self, **kwargs) -> Response: # type: ignore[override]
def request(self, *args, **kwargs) -> Response: # type: ignore[override]
"""Validate fetched response against given OpenAPI schema."""
response = super().request(**kwargs)
response_handler = ResponseHandlerFactory.create(
*args, response=response, **kwargs
)
if self._is_successful_response(response):
self.schema_tester.validate_request(response)
self.schema_tester.validate_response(response)
self.schema_tester.validate_request(response_handler=response_handler)
self.schema_tester.validate_response(response_handler=response_handler)
return response

@staticmethod
def _is_successful_response(response: Response) -> bool:
return response.status_code < 400
return response.status_code < http.HTTPStatus.BAD_REQUEST

@staticmethod
def _schema_tester_factory() -> SchemaTester:
"""Factory of default ``SchemaTester`` instances."""
return SchemaTester()


# pylint: disable=too-few-public-methods
class OpenAPINinjaClient(TestClient):
"""``APINinjaClient`` validating responses against OpenAPI schema."""

def __init__(
self,
*args,
router_or_app: NinjaAPI | Router,
path_prefix: str = "",
schema_tester: SchemaTester | None = None,
**kwargs,
) -> None:
"""Initialize ``OpenAPIClient`` instance."""
super().__init__(*args, router_or_app=router_or_app, **kwargs)
self.schema_tester = schema_tester or self._schema_tester_factory()
self._ninja_path_prefix = path_prefix

def request(self, *args, **kwargs) -> Response:
"""Validate fetched response against given OpenAPI schema."""
response = super().request(*args, **kwargs)
response_handler = ResponseHandlerFactory.create(
*args, response=response, path_prefix=self._ninja_path_prefix, **kwargs
)
if self._is_successful_response(response):
self.schema_tester.validate_request(response_handler=response_handler)
self.schema_tester.validate_response(response_handler)
return response

@staticmethod
def _is_successful_response(response: Response) -> bool:
return response.status_code < http.HTTPStatus.BAD_REQUEST

@staticmethod
def _schema_tester_factory() -> SchemaTester:
Expand Down
55 changes: 45 additions & 10 deletions openapi_tester/response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,26 @@
"""

import json
from typing import TYPE_CHECKING, Optional, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
from django.http.response import HttpResponse
from rest_framework.response import Response


class ResponseHandler:
@dataclass
class GenericRequest:
"""Generic request class for both DRF and Django Ninja."""

path: str
method: str
data: dict = field(default_factory=dict)
headers: dict = field(default_factory=dict)


class ResponseHandler(ABC):
"""
This class is used to handle the response and request data
from both DRF and Django HTTP (Django Ninja) responses.
Expand All @@ -24,10 +36,12 @@ def response(self) -> Union["Response", "HttpResponse"]:
return self._response

@property
def data(self) -> Optional[dict]: ...
@abstractmethod
def request(self) -> GenericRequest: ...

@property
def request_data(self) -> Optional[dict]: ...
@abstractmethod
def data(self) -> Optional[dict]: ...


class DRFResponseHandler(ResponseHandler):
Expand All @@ -43,23 +57,44 @@ def data(self) -> Optional[dict]:
return self.response.json() if self.response.data is not None else None # type: ignore[attr-defined]

@property
def request_data(self) -> Optional[dict]:
return self.response.renderer_context["request"].data # type: ignore[attr-defined]
def request(self) -> GenericRequest:
return GenericRequest(
path=self.response.renderer_context["request"].path, # type: ignore[attr-defined]
method=self.response.renderer_context["request"].method, # type: ignore[attr-defined]
data=self.response.renderer_context["request"].data, # type: ignore[attr-defined]
headers=self.response.renderer_context["request"].headers, # type: ignore[attr-defined]
)


class DjangoNinjaResponseHandler(ResponseHandler):
"""
Handles the response and request data from Django Ninja responses.
"""

def __init__(self, response: "HttpResponse") -> None:
def __init__(
self, *request_args, response: "HttpResponse", path_prefix: str = "", **kwargs
) -> None:
super().__init__(response)
self._request_method = request_args[0]
self._request_path = f"{path_prefix}{request_args[1]}"
self._request_data = self._build_request_data(request_args[2])
self._request_headers = kwargs

@property
def data(self) -> Optional[dict]:
return self.response.json() if self.response.content else None # type: ignore[attr-defined]

@property
def request_data(self) -> Optional[dict]:
response_body = self.response.wsgi_request.body # type: ignore[attr-defined]
return json.loads(response_body) if response_body else None
def request(self) -> GenericRequest:
return GenericRequest(
path=self._request_path,
method=self._request_method,
data=self._request_data,
headers=self._request_headers,
)

def _build_request_data(self, request_data: Any) -> dict:
try:
return json.loads(request_data)
except (json.JSONDecodeError, TypeError, ValueError):
return {}
8 changes: 5 additions & 3 deletions openapi_tester/response_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class ResponseHandlerFactory:
"""

@staticmethod
def create(response: Union[Response, "HttpResponse"]) -> "ResponseHandler":
def create(
*request_args, response: Union[Response, "HttpResponse"], **kwargs
) -> "ResponseHandler":
if isinstance(response, Response):
return DRFResponseHandler(response)
return DjangoNinjaResponseHandler(response)
return DRFResponseHandler(response=response)
return DjangoNinjaResponseHandler(*request_args, response=response, **kwargs)
42 changes: 21 additions & 21 deletions openapi_tester/schema_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
StaticSchemaLoader,
UrlStaticSchemaLoader,
)
from openapi_tester.response_handler_factory import ResponseHandlerFactory
from openapi_tester.utils import lazy_combinations, normalize_schema_section
from openapi_tester.validators import (
validate_enum,
Expand All @@ -54,10 +53,7 @@
if TYPE_CHECKING:
from typing import Optional

from django.http.response import HttpResponse
from rest_framework.response import Response

from openapi_tester.response_handler import ResponseHandler
from openapi_tester.response_handler import GenericRequest, ResponseHandler


class SchemaTester:
Expand Down Expand Up @@ -180,9 +176,9 @@ def get_response_schema_section(
response = response_handler.response
schema = self.loader.get_schema()

response_method = response.request["REQUEST_METHOD"].lower() # type: ignore
response_method = response_handler.request.method.lower()
parameterized_path, _ = self.loader.resolve_path(
response.request["PATH_INFO"], # type: ignore
response_handler.request.path,
method=response_method,
)
paths_object = self.get_paths_object()
Expand Down Expand Up @@ -246,19 +242,20 @@ def get_response_schema_section(
return {}

def get_request_body_schema_section(
self, request: dict[str, Any]
self, request: GenericRequest
) -> dict[str, Any]:
"""
Fetches the request section of a schema.
:param response: DRF Request Instance
:return dict
"""
request_method = request["REQUEST_METHOD"].lower()
request_method = request.method.lower() # request["REQUEST_METHOD"].lower()

parameterized_path, _ = self.loader.resolve_path(
request["PATH_INFO"], method=request_method
request.path, method=request_method
)

paths_object = self.get_paths_object()

route_object = self.get_key_value(
Expand All @@ -277,22 +274,25 @@ def get_request_body_schema_section(
),
)

if all(
key in request for key in ["CONTENT_LENGTH", "CONTENT_TYPE", "wsgi.input"]
):
if request["CONTENT_TYPE"] != "application/json":
if request.data:
if not any(
"application/json" == request.headers.get(content_type)
for content_type in ["content_type", "CONTENT_TYPE", "Content-Type"]
):
return {}

request_body_object = self.get_key_value(
method_object,
"requestBody",
f"\n\nNo request body documented for method: {request_method}, path: {parameterized_path}",
)

content_object = self.get_key_value(
request_body_object,
"content",
f"\n\nNo content documented for method: {request_method}, path: {parameterized_path}",
)

json_object = self.get_key_value(
content_object,
r"^application\/.*json$",
Expand All @@ -302,6 +302,7 @@ def get_request_body_schema_section(
),
use_regex=True,
)

return self.get_key_value(json_object, "schema")

return {}
Expand Down Expand Up @@ -568,7 +569,7 @@ def test_openapi_array(

def validate_request(
self,
response: Response | HttpResponse,
response_handler: ResponseHandler,
test_config: OpenAPITestConfig | None = None,
) -> None:
"""
Expand All @@ -582,25 +583,26 @@ def validate_request(
:raises: ``openapi_tester.exceptions.DocumentationError`` for inconsistencies in the API response and schema.
``openapi_tester.exceptions.CaseError`` for case errors.
"""
response_handler = ResponseHandlerFactory.create(response)
if self.is_openapi_schema():
# TODO: Implement for other schema types
if test_config:
test_config.http_message = "request"
else:
test_config = OpenAPITestConfig(http_message="request")
request_body_schema = self.get_request_body_schema_section(response.request) # type: ignore
request_body_schema = self.get_request_body_schema_section(
response_handler.request
)

if request_body_schema:
self.test_schema_section(
schema_section=request_body_schema,
data=response_handler.request_data,
data=response_handler.request.data,
test_config=test_config,
)

def validate_response(
self,
response: Response | HttpResponse,
response_handler: ResponseHandler,
test_config: OpenAPITestConfig | None = None,
) -> None:
"""
Expand All @@ -613,8 +615,6 @@ def validate_response(
:raises: ``openapi_tester.exceptions.DocumentationError`` for inconsistencies in the API response and schema.
``openapi_tester.exceptions.CaseError`` for case errors.
"""
response_handler = ResponseHandlerFactory.create(response)

if test_config:
test_config.http_message = "response"
else:
Expand Down
Loading

0 comments on commit 72c78a1

Please sign in to comment.