Skip to content

Commit

Permalink
Allow multiple reads of request body
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly authored Aug 1, 2024
1 parent 6011726 commit e6b45cf
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.proxyserver;

import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;

public class MultiReadHttpServletRequest
extends HttpServletRequestWrapper
{
private final byte[] content;

public MultiReadHttpServletRequest(HttpServletRequest request, String body)
{
super(request);
content = body.getBytes(StandardCharsets.UTF_8);
}

@Override
public ServletInputStream getInputStream()
throws IOException
{
final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(content);
return new ServletInputStream()
{
@Override
public boolean isFinished()
{
return byteArrayInputStream.available() > 0;
}

@Override
public boolean isReady()
{
return false;
}

@Override
public void setReadListener(ReadListener readListener) {}

@Override
public int read()
throws IOException
{
return byteArrayInputStream.read();
}
};
}

@Override
public BufferedReader getReader()
throws IOException
{
return new BufferedReader(new StringReader(new String(content, StandardCharsets.UTF_8)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ public void postHandler(
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
if (servletRequest.getRequestURI().startsWith(V1_STATEMENT_PATH)) {
MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(servletRequest, body);
if (multiReadHttpServletRequest.getRequestURI().startsWith(V1_STATEMENT_PATH)) {
proxyHandlerStats.recordRequest();
}
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.postRequest(body, servletRequest, asyncResponse, URI.create(remoteUri));
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.postRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
}

@GET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ public static TestConfig buildGatewayConfigAndSeedDb(int routerPort, String conf
.replace(
"APPLICATION_CONNECTOR_PORT", String.valueOf(30000 + (int) (Math.random() * 1000)))
.replace("ADMIN_CONNECTOR_PORT", String.valueOf(31000 + (int) (Math.random() * 1000)))
.replace("LOCALHOST_JKS", Paths.get(resource.toURI()).toFile().getAbsolutePath());
.replace("LOCALHOST_JKS", Paths.get(resource.toURI()).toFile().getAbsolutePath())
.replace("RESOURCES_DIR", Paths.get("src", "test", "resources").toFile().getAbsolutePath());

File target = File.createTempFile("config-" + System.currentTimeMillis(), "config.yaml");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.containers.TrinoContainer;

import static org.assertj.core.api.Assertions.assertThat;
import static org.testcontainers.utility.MountableFile.forClasspathResource;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestGatewayHaWithRoutingRulesSingleBackend
{
private final OkHttpClient httpClient = new OkHttpClient();
private TrinoContainer trino;
int routerPort = 21001 + (int) (Math.random() * 1000);

@BeforeAll
public void setup()
throws Exception
{
trino = new TrinoContainer("trinodb/trino");
trino.withCopyFileToContainer(forClasspathResource("trino-config.properties"), "/etc/trino/config.properties");
trino.start();

int backendPort = trino.getMappedPort(8080);

// seed database
HaGatewayTestUtils.TestConfig testConfig =
HaGatewayTestUtils.buildGatewayConfigAndSeedDb(routerPort, "test-config-with-routing-template.yml");
// Start Gateway
String[] args = {testConfig.configFilePath()};
HaGatewayLauncher.main(args);
// Now populate the backend
HaGatewayTestUtils.setUpBackend(
"trino1", "http://localhost:" + backendPort, "externalUrl", true, "system", routerPort);
}

@Test
public void testRequestDelivery()
throws Exception
{
RequestBody requestBody =
RequestBody.create(MediaType.parse("application/json; charset=utf-8"), "SELECT * from system.runtime.nodes");
Request request =
new Request.Builder()
.url("http://localhost:" + routerPort + "/v1/statement")
.addHeader("X-Trino-User", "test")
.post(requestBody)
.build();
Response response = httpClient.newCall(request).execute();
assertThat(response.body().string()).contains("nextUri");
}

// Do not allow trino gateway to fall back to the adhoc routing group if the desired backend is not found
@Test
public void testVerifyNoAdhoc()
throws Exception
{
Request request = new Request.Builder()
.url("http://localhost:" + routerPort + "/entity/GATEWAY_BACKEND")
.method("GET", null)
.build();
Response response = httpClient.newCall(request).execute();

final ObjectMapper objectMapper = new ObjectMapper();
ProxyBackendConfiguration[] backendConfiguration =
objectMapper.readValue(response.body().string(), ProxyBackendConfiguration[].class);

assertThat(backendConfiguration).hasSize(1);
assertThat(backendConfiguration[0].isActive()).isTrue();
assertThat(backendConfiguration[0].getRoutingGroup()).isNotEqualTo("adhoc");
}

@AfterAll
public void cleanup()
{
trino.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ condition: |
actions:
- "result.put(\"routingGroup\", \"defaults-group\")"
---
name: "system-group"
description: "capture queries to system catalog"
condition: |
trinoQueryProperties.getCatalogs().contains("system")
actions:
- "result.put(\"routingGroup\", \"system\")"

---
name: "nomatch"
priority: -1
description: "default group to catch if no other rules fired"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
serverConfig:
node.environment: test
http-server.http.port: REQUEST_ROUTER_PORT

dataStore:
jdbcUrl: jdbc:h2:DB_FILE_PATH
user: sa
password: sa
driver: org.h2.Driver

modules:
- io.trino.gateway.ha.module.HaGatewayProviderModule

extraWhitelistPaths:
- '/v1/custom.*'
- '/custom/logout.*'

gatewayCookieConfiguration:
enabled: true
cookieSigningSecret: "kjlhbfrewbyuo452cds3dc1234ancdsjh"

oauth2GatewayCookieConfiguration:
deletePaths:
- "/custom/logout"

requestAnalyzerConfig:
analyzeRequest: true

routingRules:
rulesEngineEnabled: true
rulesConfigPath: "RESOURCES_DIR/rules/routing_rules_trino_query_properties.yml"

0 comments on commit e6b45cf

Please sign in to comment.