Skip to content

Commit

Permalink
Adding json_serialize and json_deserialize to requests transport (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz authored Feb 8, 2024
1 parent a3f0bd9 commit e5c7c8f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 6 deletions.
23 changes: 17 additions & 6 deletions gql/transport/requests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import json
import logging
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union

import requests
from graphql import DocumentNode, ExecutionResult, print_ast
Expand Down Expand Up @@ -47,6 +47,8 @@ def __init__(
method: str = "POST",
retry_backoff_factor: float = 0.1,
retry_status_forcelist: Collection[int] = _default_retry_codes,
json_serialize: Callable = json.dumps,
json_deserialize: Callable = json.loads,
**kwargs: Any,
):
"""Initialize the transport with the given request parameters.
Expand All @@ -73,6 +75,10 @@ def __init__(
should force a retry on. A retry is initiated if the request method is
in allowed_methods and the response status code is in status_forcelist.
(Default: [429, 500, 502, 503, 504])
:param json_serialize: Json serializer callable.
By default json.dumps() function
:param json_deserialize: Json deserializer callable.
By default json.loads() function
:param kwargs: Optional arguments that ``request`` takes.
These can be seen at the `requests`_ source code or the official `docs`_
Expand All @@ -90,6 +96,8 @@ def __init__(
self.method = method
self.retry_backoff_factor = retry_backoff_factor
self.retry_status_forcelist = retry_status_forcelist
self.json_serialize: Callable = json_serialize
self.json_deserialize: Callable = json_deserialize
self.kwargs = kwargs

self.session = None
Expand Down Expand Up @@ -174,7 +182,7 @@ def execute( # type: ignore
payload["variables"] = nulled_variable_values

# Add the payload to the operations field
operations_str = json.dumps(payload)
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)

# Generate the file map
Expand All @@ -188,7 +196,7 @@ def execute( # type: ignore
file_streams = {str(i): files[path] for i, path in enumerate(files)}

# Add the file map field
file_map_str = json.dumps(file_map)
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)

fields = {"operations": operations_str, "map": file_map_str}
Expand Down Expand Up @@ -224,7 +232,7 @@ def execute( # type: ignore

# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(payload))
log.info(">>> %s", self.json_serialize(payload))

# Pass kwargs to requests post method
post_args.update(self.kwargs)
Expand Down Expand Up @@ -257,7 +265,10 @@ def raise_response_error(resp: requests.Response, reason: str):
)

try:
result = response.json()
if self.json_deserialize == json.loads:
result = response.json()
else:
result = self.json_deserialize(response.text)

if log.isEnabledFor(logging.INFO):
log.info("<<< %s", response.text)
Expand Down Expand Up @@ -396,7 +407,7 @@ def _build_batch_post_args(

# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(post_args[data_key]))
log.info(">>> %s", self.json_serialize(post_args[data_key]))

# Pass kwargs to requests post method
post_args.update(self.kwargs)
Expand Down
106 changes: 106 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,3 +923,109 @@ def test_code():
assert transport.session is None

await run_sync_test(event_loop, server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_requests_json_serializer(
event_loop, aiohttp_server, run_sync_test, caplog
):
import json
from aiohttp import web
from gql.transport.requests import RequestsHTTPTransport

async def handler(request):

request_text = await request.text()
print("Received on backend: " + request_text)

return web.Response(
text=query1_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
transport = RequestsHTTPTransport(
url=url,
json_serialize=lambda e: json.dumps(e, separators=(",", ":")),
)

with Client(transport=transport) as session:

query = gql(query1_str)

# Execute query asynchronously
result = session.execute(query)

continents = result["continents"]

africa = continents[0]

assert africa["code"] == "AF"

# Checking that there is no space after the colon in the log
expected_log = '"query":"query getContinents'
assert expected_log in caplog.text

await run_sync_test(event_loop, server, test_code)


query_float_str = """
query getPi {
pi
}
"""

query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}'

query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}'


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test):
import json
from aiohttp import web
from decimal import Decimal
from functools import partial
from gql.transport.requests import RequestsHTTPTransport

async def handler(request):
return web.Response(
text=query_float_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():

json_loads = partial(json.loads, parse_float=Decimal)

transport = RequestsHTTPTransport(
url=url,
json_deserialize=json_loads,
)

with Client(transport=transport) as session:

query = gql(query_float_str)

# Execute query asynchronously
result = session.execute(query)

pi = result["pi"]

assert pi == Decimal("3.141592653589793238462643383279502884197")

await run_sync_test(event_loop, server, test_code)

0 comments on commit e5c7c8f

Please sign in to comment.