diff --git a/django_query_prefixer/middlewares/__init__.py b/django_query_prefixer/middlewares/__init__.py index e6f7933..512ad76 100644 --- a/django_query_prefixer/middlewares/__init__.py +++ b/django_query_prefixer/middlewares/__init__.py @@ -1,21 +1,27 @@ from django.urls import resolve -from django_query_prefixer import sql_prefixes +from django_query_prefixer import set_prefix, remove_prefix + + +class RequestRouteMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + response = self.get_response(request) + remove_prefix("view_name") + remove_prefix("route") + return response + + def process_view(self, request, view_func, view_args, view_kwargs): + set_prefix(key="view_name", value=f"{view_func.__module__}.{view_func.__name__}") + set_prefix( + key="route", + value=escape_comment_markers(request.resolver_match.route.route) + ) def request_route(get_response): - def middleware(request): - if request.resolver_match is not None: - route = request.resolver_match.route - else: - route = resolve(request.path_info) - - with sql_prefixes( - view_name=route.view_name, - route=escape_comment_markers(route.route), - ): - return get_response(request) - - return middleware + return RequestRouteMiddleware(get_response) def escape_comment_markers(value): diff --git a/tests/test_middlewares.py b/tests/test_middlewares.py index bb71672..868cd1b 100644 --- a/tests/test_middlewares.py +++ b/tests/test_middlewares.py @@ -10,14 +10,15 @@ def test_request_route_middleware(): request = mock.MagicMock() request.resolver_match.route.route = "/hello" - request.resolver_match.route.view_name = "hello_world" + def hello_world(): + pass - with mock.patch("django_query_prefixer.middlewares.sql_prefixes") as mock_sql_prefixes: + with mock.patch("django_query_prefixer.middlewares.set_prefix") as mock_set_prefix, \ + mock.patch("django_query_prefixer.middlewares.remove_prefix") as mock_remove_prefix: + middleware.process_view(request, hello_world, [], {}) assert middleware(request) == response - mock_sql_prefixes.assert_called_with( - view_name="hello_world", - route="/hello", - ) - - get_response.assert_called_once_with(request) + assert mock_set_prefix.call_args_list[0].kwargs == {"key": "view_name", "value": f"{hello_world.__module__}.hello_world"} + assert mock_set_prefix.call_args_list[1].kwargs == {"key": "route", "value": "/hello"} + assert mock_remove_prefix.call_args_list[0].args[0] == "view_name" + assert mock_remove_prefix.call_args_list[1].args[0] == "route"