Skip to content

Commit

Permalink
Add requestIds to server HTTP error response.
Browse files Browse the repository at this point in the history
HTTP errors are returned as JSON without requestId. The requestId is
required in order to form a ResponseMessage that can be properly
deserialized by the Java driver.
  • Loading branch information
kenhuuu committed Dec 6, 2023
1 parent 1e0cfae commit 2c1267f
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 44 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
[[release-3-6-7]]
=== TinkerPop 3.6.7 (NOT OFFICIALLY RELEASED YET)
* Fixed issue where server errors weren't being properly parsed when sending bytecode over HTTP.
[[release-3-6-6]]
=== TinkerPop 3.6.6 (November 20, 2023)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,45 @@
*/
package org.apache.tinkerpop.gremlin.driver.handler;

import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.CharsetUtil;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.apache.tinkerpop.gremlin.driver.MessageSerializer;
import org.apache.tinkerpop.gremlin.driver.Tokens;
import org.apache.tinkerpop.gremlin.driver.message.ResponseMessage;
import org.apache.tinkerpop.gremlin.driver.ser.MessageTextSerializer;
import org.apache.tinkerpop.gremlin.driver.message.ResponseStatusCode;
import org.apache.tinkerpop.gremlin.driver.ser.SerTokens;
import org.apache.tinkerpop.shaded.jackson.databind.JsonNode;
import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper;

import java.util.List;
import java.util.UUID;

/**
* Converts {@code HttpResponse} to a {@link ResponseMessage}.
*/
@ChannelHandler.Sharable
public final class HttpGremlinResponseDecoder extends MessageToMessageDecoder<FullHttpResponse> {
private final MessageSerializer<?> serializer;
private final ObjectMapper mapper = new ObjectMapper();

public HttpGremlinResponseDecoder(final MessageSerializer<?> serializer) {
this.serializer = serializer;
}

@Override
protected void decode(final ChannelHandlerContext channelHandlerContext, final FullHttpResponse httpResponse, final List<Object> objects) throws Exception {
objects.add(serializer.deserializeResponse(httpResponse.content()));
if (httpResponse.status() == HttpResponseStatus.OK) {
objects.add(serializer.deserializeResponse(httpResponse.content()));
} else {
final JsonNode root = mapper.readTree(new ByteBufInputStream(httpResponse.content()));
objects.add(ResponseMessage.build(UUID.fromString(root.get(Tokens.REQUEST_ID).asText()))
.code(ResponseStatusCode.SERVER_ERROR)
.statusMessage(root.get(SerTokens.TOKEN_MESSAGE).asText())
.create());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def read(self):
# Inner function to perform async read.
async def async_read():
async with async_timeout.timeout(self._read_timeout):
return await self._http_req_resp.read()
return {"content": await self._http_req_resp.read(),
"ok": self._http_req_resp.ok,
"status": self._http_req_resp.status}

return self._loop.run_until_complete(async_read())

Expand Down
46 changes: 24 additions & 22 deletions gremlin-python/src/main/python/gremlin_python/driver/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,35 +223,37 @@ def write(self, request_id, request_message):

self._transport.write(message)

def data_received(self, message, results_dict):
def data_received(self, response, results_dict):
# if Gremlin Server cuts off then we get a None for the message
if message is None:
if response is None:
log.error("Received empty message from server.")
raise GremlinServerError({'code': 500,
'message': 'Server disconnected - please try to reconnect', 'attributes': {}})

message = self._message_serializer.deserialize_message(message)
request_id = message['requestId']
result_set = results_dict[request_id] if request_id in results_dict else ResultSet(None, None)
status_code = message['status']['code']
aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
data = message['result']['data']
result_set.aggregate_to = aggregate_to

if status_code == 204:
result_set.stream.put_nowait([])
del results_dict[request_id]
return status_code
elif status_code in [200, 206]:
result_set.stream.put_nowait(data)
if status_code == 200:
result_set.status_attributes = message['status']['attributes']
if response['ok']:
message = self._message_serializer.deserialize_message(response['content'])
request_id = message['requestId']
result_set = results_dict[request_id] if request_id in results_dict else ResultSet(None, None)
status_code = message['status']['code']
aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
data = message['result']['data']
result_set.aggregate_to = aggregate_to

if status_code == 204:
result_set.stream.put_nowait([])
del results_dict[request_id]
return status_code
return status_code
elif status_code in [200, 206]:
result_set.stream.put_nowait(data)
if status_code == 200:
result_set.status_attributes = message['status']['attributes']
del results_dict[request_id]
return status_code
else:
# This message is going to be huge and kind of hard to read, but in the event of an error,
# it can provide invaluable info, so space it out appropriately.
log.error("\r\nReceived error message '%s'\r\n\r\nWith results dictionary '%s'",
str(message), str(results_dict))
del results_dict[request_id]
raise GremlinServerError(message['status'])
str(response['content']), str(results_dict))
body = json.loads(response['content'])
del results_dict[body['requestId']]
raise GremlinServerError({'code': response['status'], 'message': body['message'], 'attributes': {}})
21 changes: 21 additions & 0 deletions gremlin-python/src/main/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,24 @@ def fin():
request.addfinalizer(fin)
return remote_conn
"""


@pytest.fixture(params=['graphsonv3', 'graphbinaryv1'])
def invalid_alias_remote_connection_http(request):
try:
if request.param == 'graphbinaryv1':
remote_conn = DriverRemoteConnection(anonymous_url_http, 'does_not_exist',
message_serializer=serializer.GraphBinarySerializersV1())
elif request.param == 'graphsonv3':
remote_conn = DriverRemoteConnection(anonymous_url_http, 'does_not_exist',
message_serializer=serializer.GraphSONSerializersV3d0())
else:
raise ValueError("Invalid serializer option - " + request.param)
except OSError:
pytest.skip('Gremlin Server is not running')
else:
def fin():
remote_conn.close()

request.addfinalizer(fin)
return remote_conn
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from gremlin_python.process.strategies import SubgraphStrategy, SeedStrategy
from gremlin_python.structure.io.util import HashableDict
from gremlin_python.driver.serializer import GraphSONSerializersV2d0
from gremlin_python.driver.protocol import GremlinServerError

gremlin_server_url_http = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/')
test_no_auth_http_url = gremlin_server_url_http.format(45940)
Expand Down Expand Up @@ -212,6 +213,15 @@ def test_clone(self, remote_connection_http):
assert 5 == t.clone().limit(5).count().next()
assert 10 == t.clone().limit(10).count().next()

def test_receive_error(self, invalid_alias_remote_connection_http):
g = traversal().withRemote(invalid_alias_remote_connection_http)

try:
g.V().next()
assert False
except GremlinServerError as err:
assert err.status_code == 400
assert 'Could not rebind' in err.status_message

"""
# The WsAndHttpChannelizer somehow does not distinguish the ssl handlers so authenticated https remote connection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,20 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof FullHttpMessage){
final FullHttpMessage request = (FullHttpMessage) msg;
final boolean keepAlive = HttpUtil.isKeepAlive(request);
final RequestMessage requestMessage;
try {
requestMessage = HttpHandlerUtil.getRequestMessageFromHttpRequest((FullHttpRequest) request);
} catch (IllegalArgumentException iae) {
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, iae.getMessage(), keepAlive);
return;
}

try {
user = ctx.channel().attr(StateKey.AUTHENTICATED_USER).get();
if (null == user) { // This is expected when using the AllowAllAuthenticator
user = AuthenticatedUser.ANONYMOUS_USER;
}
final RequestMessage requestMessage = HttpHandlerUtil.getRequestMessageFromHttpRequest((FullHttpRequest) request);

authorizer.authorize(user, requestMessage);
ctx.fireChannelRead(request);
} catch (AuthorizationException ex) { // Expected: users can alternate between allowed and disallowed requests
Expand All @@ -77,18 +84,18 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
try {
script = HttpHandlerUtil.getRequestMessageFromHttpRequest((FullHttpRequest) request).getArgOrDefault(Tokens.ARGS_GREMLIN, "");
} catch (IllegalArgumentException iae) {
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, iae.getMessage(), keepAlive);
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, requestMessage.getRequestId(), iae.getMessage(), keepAlive);
return;
}
auditLogger.info("User {} with address {} attempted an unauthorized http request: {}",
user.getName(), address, script);
final String message = String.format("No authorization for script [%s] - check permissions.", script);
HttpHandlerUtil.sendError(ctx, UNAUTHORIZED, message, keepAlive);
HttpHandlerUtil.sendError(ctx, UNAUTHORIZED, requestMessage.getRequestId(), message, keepAlive);
ReferenceCountUtil.release(msg);
} catch (Exception ex) {
final String message = String.format(
"%s is not ready to handle requests - unknown error", authorizer.getClass().getSimpleName());
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, message, keepAlive);
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, requestMessage.getRequestId(), message, keepAlive);
ReferenceCountUtil.release(msg);
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -154,10 +155,11 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
return;
}

final UUID requestId = requestMessage.getRequestId();
final String acceptMime = Optional.ofNullable(req.headers().get(HttpHeaderNames.ACCEPT)).orElse("application/json");
final Pair<String, MessageTextSerializer<?>> serializer = chooseSerializer(acceptMime);
if (null == serializer) {
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, String.format("no serializer for requested Accept header: %s", acceptMime),
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, requestId, String.format("no serializer for requested Accept header: %s", acceptMime),
keepAlive);
ReferenceCountUtil.release(msg);
return;
Expand Down Expand Up @@ -210,7 +212,7 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
bindings = createBindings(requestMessage.getArgOrDefault(Tokens.ARGS_BINDINGS, Collections.emptyMap()),
requestMessage.getArgOrDefault(Tokens.ARGS_ALIASES, Collections.emptyMap()));
} catch (IllegalStateException iae) {
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, iae.getMessage(), keepAlive);
HttpHandlerUtil.sendError(ctx, BAD_REQUEST, requestId, iae.getMessage(), keepAlive);
ReferenceCountUtil.release(msg);
return;
}
Expand All @@ -237,7 +239,7 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
final List<Object> results = requestMessage.getOp().equals(Tokens.OPS_BYTECODE) ?
(List<Object>) IteratorUtils.asList(o).stream().map(r -> new DefaultRemoteTraverser<Object>(r, 1)).collect(Collectors.toList()) :
IteratorUtils.asList(o);
final ResponseMessage responseMessage = ResponseMessage.build(requestMessage.getRequestId())
final ResponseMessage responseMessage = ResponseMessage.build(requestId)
.code(ResponseStatusCode.SUCCESS)
.result(results).create();

Expand Down Expand Up @@ -269,9 +271,9 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {

evalFuture.exceptionally(t -> {
if (t.getMessage() != null)
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, t.getMessage(), Optional.of(t), keepAlive);
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, requestId, t.getMessage(), Optional.of(t), keepAlive);
else
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, String.format("Error encountered evaluating script: %s",
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, requestId, String.format("Error encountered evaluating script: %s",
requestMessage.getArg(Tokens.ARGS_GREMLIN))
, Optional.of(t), keepAlive);
promise.setFailure(t);
Expand All @@ -289,11 +291,11 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
// context on whether to close the connection or not, based on keepalive.
final Throwable t = ExceptionHelper.getRootCause(ex);
if (t instanceof TooLongFrameException) {
HttpHandlerUtil.sendError(ctx, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, t.getMessage() + " - increase the maxContentLength", keepAlive);
HttpHandlerUtil.sendError(ctx, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, requestId, t.getMessage() + " - increase the maxContentLength", keepAlive);
} else if (t != null){
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, t.getMessage(), keepAlive);
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, requestId, t.getMessage(), keepAlive);
} else {
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, ex.getMessage(), keepAlive);
HttpHandlerUtil.sendError(ctx, INTERNAL_SERVER_ERROR, requestId, ex.getMessage(), keepAlive);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,15 @@ else if (node.isBoolean())

static void sendError(final ChannelHandlerContext ctx, final HttpResponseStatus status,
final String message, final boolean keepAlive) {
sendError(ctx, status, message, Optional.empty(), keepAlive);
sendError(ctx, status, null, message, Optional.empty(), keepAlive);
}

static void sendError(final ChannelHandlerContext ctx, final HttpResponseStatus status,
static void sendError(final ChannelHandlerContext ctx, final HttpResponseStatus status, final UUID requestId,
final String message, final boolean keepAlive) {
sendError(ctx, status, requestId, message, Optional.empty(), keepAlive);
}

static void sendError(final ChannelHandlerContext ctx, final HttpResponseStatus status, final UUID requestId,
final String message, final Optional<Throwable> t, final boolean keepAlive) {
if (t.isPresent())
logger.warn(String.format("Invalid request - responding with %s and %s", status, message), t.get());
Expand All @@ -251,6 +256,9 @@ static void sendError(final ChannelHandlerContext ctx, final HttpResponseStatus
ExceptionUtils.getThrowableList(t.get()).forEach(throwable -> exceptionList.add(throwable.getClass().getName()));
node.put(Tokens.STATUS_ATTRIBUTE_STACK_TRACE, ExceptionUtils.getStackTrace(t.get()));
}
if (requestId != null) {
node.put("requestId", requestId.toString());
}

final FullHttpResponse response = new DefaultFullHttpResponse(
HTTP_1_1, status, Unpooled.copiedBuffer(node.toString(), CharsetUtil.UTF_8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,38 @@ public void shouldFailToUseTx() throws Exception {
cluster.close();
}
}

@Test
public void shouldDeserializeErrorWithGraphBinary() throws Exception {
final Cluster cluster = TestClientFactory.build()
.channelizer(Channelizer.HttpChannelizer.class)
.serializer(Serializers.GRAPHBINARY_V1D0)
.create();
try {
final GraphTraversalSource g = traversal().withRemote(DriverRemoteConnection.using(cluster, "doesNotExist"));
g.V().next();
fail("Expected exception to be thrown.");
} catch (Exception ex) {
assert ex.getMessage().contains("Could not rebind");
} finally {
cluster.close();
}
}

@Test
public void shouldDeserializeErrorWithGraphSON() throws Exception {
final Cluster cluster = TestClientFactory.build()
.channelizer(Channelizer.HttpChannelizer.class)
.serializer(Serializers.GRAPHSON_V3D0)
.create();
try {
final GraphTraversalSource g = traversal().withRemote(DriverRemoteConnection.using(cluster, "doesNotExist"));
g.V().next();
fail("Expected exception to be thrown.");
} catch (Exception ex) {
assert ex.getMessage().contains("Could not rebind");
} finally {
cluster.close();
}
}
}

0 comments on commit 2c1267f

Please sign in to comment.