Skip to content

Commit

Permalink
Set X-Forwarded-Host when client using HTTP/2
Browse files Browse the repository at this point in the history
  • Loading branch information
oneonestar authored and mosabua committed Aug 30, 2024
1 parent 2050251 commit 2665198
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import java.util.concurrent.ExecutorService;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.net.HttpHeaders.HOST;
import static com.google.common.net.HttpHeaders.VIA;
import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR;
import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST;
Expand Down Expand Up @@ -290,9 +289,9 @@ private void addXForwardedHeaders(HttpServletRequest servletRequest, Request.Bui
requestBuilder.addHeader(X_FORWARDED_FOR, servletRequest.getRemoteAddr());
requestBuilder.addHeader(X_FORWARDED_PROTO, servletRequest.getScheme());
requestBuilder.addHeader(X_FORWARDED_PORT, String.valueOf(servletRequest.getServerPort()));
String hostHeader = servletRequest.getHeader(HOST);
if (hostHeader != null) {
requestBuilder.addHeader(X_FORWARDED_HOST, hostHeader);
String serverName = servletRequest.getServerName();
if (serverName != null) {
requestBuilder.addHeader(X_FORWARDED_HOST, serverName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,30 @@
package io.trino.gateway.ha;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.testcontainers.containers.TrinoContainer;

import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.testcontainers.utility.MountableFile.forClasspathResource;

@TestInstance(Lifecycle.PER_CLASS)
public class TestGatewayHaSingleBackend
{
private final OkHttpClient httpClient = new OkHttpClient();
private TrinoContainer trino;
int routerPort = 21001 + (int) (Math.random() * 1000);

Expand All @@ -58,8 +62,18 @@ public void setup()
"trino1", "http://localhost:" + backendPort, "externalUrl", true, "adhoc", routerPort);
}

@Test
public void testRequestDelivery()
public Stream<OkHttpClient> getOkHttpClient()
{
OkHttpClient.Builder http1Builder = new OkHttpClient.Builder();
http1Builder.protocols(ImmutableList.of(Protocol.HTTP_1_1));
OkHttpClient.Builder http2Builder = new OkHttpClient.Builder();
http2Builder.protocols(ImmutableList.of(Protocol.H2_PRIOR_KNOWLEDGE));
return Stream.of(http1Builder.build(), http2Builder.build());
}

@ParameterizedTest
@MethodSource("getOkHttpClient")
public void testRequestDelivery(OkHttpClient httpClient)
throws Exception
{
RequestBody requestBody =
Expand All @@ -68,18 +82,23 @@ public void testRequestDelivery()
new Request.Builder()
.url("http://localhost:" + routerPort + "/v1/statement")
.addHeader("X-Trino-User", "test")
.addHeader("Host", "test.host.com")
.post(requestBody)
.build();
Response response = httpClient.newCall(request).execute();
assertThat(response.body().string()).contains("nextUri");
String responseBody = response.body().string();
assertThat(responseBody).contains("nextUri");
assertThat(responseBody).contains("test.host.com");
}

@Test
public void testBackendConfiguration()
@ParameterizedTest
@MethodSource("getOkHttpClient")
public void testBackendConfiguration(OkHttpClient httpClient)
throws Exception
{
Request request = new Request.Builder()
.url("http://localhost:" + routerPort + "/entity/GATEWAY_BACKEND")
.addHeader("Host", "test.host.com")
.method("GET", null)
.build();
Response response = httpClient.newCall(request).execute();
Expand Down

0 comments on commit 2665198

Please sign in to comment.