diff --git a/graphql_persist/middleware.py b/graphql_persist/middleware.py index 548b1e3..2b7ea08 100644 --- a/graphql_persist/middleware.py +++ b/graphql_persist/middleware.py @@ -5,7 +5,7 @@ from . import exceptions from .loaders import DocumentDoesNotExist, DocumentImportError -from .parser import parse_json +from .parser import parse_body, parse_json from .query import QueryKey from .settings import persist_settings @@ -45,12 +45,11 @@ def __call__(self, request): def process_view(self, request, view_func, *args): if (hasattr(view_func, 'view_class') and - issubclass(view_func.view_class, GraphQLView) and - request.content_type == 'application/json'): + issubclass(view_func.view_class, GraphQLView)): - try: - data = parse_json(request.body) - except ValueError: + data = parse_body(request) + + if data is None: return None query_id = data.get('id', data.get('operationName')) @@ -68,7 +67,11 @@ def process_view(self, request, view_func, *args): return exceptions.DocumentSyntaxError(str(e)) request.persisted_query = PersistedQuery(document, data) - request._body = json.dumps(data).encode() + + if request.content_type == 'application/json': + request._body = json.dumps(data).encode() + else: + request.POST = data return None def get_query_key(self, query_id, request): diff --git a/graphql_persist/parser.py b/graphql_persist/parser.py index 600d1ab..10f454c 100644 --- a/graphql_persist/parser.py +++ b/graphql_persist/parser.py @@ -7,3 +7,18 @@ def parse_json(content, **kwargs): content = force_text(content, **kwargs) return json.loads(content, object_pairs_hook=OrderedDict) + + +def parse_body(request): + if request.content_type == 'application/json': + try: + return parse_json(request.body) + except ValueError: + return None + + elif request.content_type in ( + 'application/x-www-form-urlencoded', + 'multipart/form-data'): + + return request.POST.dict() + return None diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 17ed004..e6b73d2 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -9,18 +9,21 @@ from graphql_persist import exceptions, versioning from graphql_persist.middleware import PersistedQuery, PersistMiddleware -from graphql_persist.parser import parse_json +from graphql_persist.parser import parse_body, parse_json from graphql_persist.renderers import BaseRenderer from graphql_persist.settings import persist_settings from .decorators import override_persist_settings -class JSONRequestFactory(RequestFactory): +class VersioningRequestFactory(RequestFactory): def request(self, **request): return VersioningRequest(self._base_environ(**request)) + +class JSONRequestFactory(VersioningRequestFactory): + def post(self, path, data=None, *args, **kwargs): kwargs.setdefault('content_type', 'application/json') @@ -62,12 +65,29 @@ def test_process_view(self): result = self.middleware.process_view(request, self.view_func) persisted_query = request.persisted_query document = persisted_query.document - request_body = parse_json(request._body) + body = parse_body(request) + + self.assertIsNone(result) + self.assertEqual(persisted_query.id, 'schema') + self.assertEqual(document.origin.query_key._keys, ['schema']) + self.assertEqual(document.source.body, body['query']) + + @override_settings(INSTALLED_APPS=['tests']) + def test_process_view_x_www_form_urlencoded(self): + data = { + 'id': 'schema', + } + + request = VersioningRequestFactory().post('/', data=data) + result = self.middleware.process_view(request, self.view_func) + persisted_query = request.persisted_query + document = persisted_query.document + body = request.POST self.assertIsNone(result) self.assertEqual(persisted_query.id, 'schema') self.assertEqual(document.origin.query_key._keys, ['schema']) - self.assertEqual(document.source.body, request_body['query']) + self.assertEqual(document.source.body, body['query']) def test_missing_id(self): request = self.factory.post('/', data={}) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..24af541 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,44 @@ +import json + +from django.test import RequestFactory, testcases + +from graphql_persist import parser + + +class ParserTests(testcases.TestCase): + + def setUp(self): + self.factory = RequestFactory() + + def test_application_json(self): + request = self.factory.post( + '/', + data=json.dumps({'test': True}), + content_type='application/json') + + result = parser.parse_body(request) + self.assertTrue(result['test']) + + def test_json_decode_error(self): + request = self.factory.post( + '/', + data='error', + content_type='application/json') + + result = parser.parse_body(request) + self.assertIsNone(result) + + def test_x_www_form_urlencoded(self): + request = self.factory.post('/', data={'test': True}) + result = parser.parse_body(request) + + self.assertTrue(eval(result['test'])) + + def test_unknown_content_type(self): + request = self.factory.post( + '/', + data={'test': True}, + content_type='unknown') + + result = parser.parse_body(request) + self.assertIsNone(result)