From 2665198a53de4a63ce250623de9386cdd8e33cbc Mon Sep 17 00:00:00 2001 From: Star Poon Date: Fri, 30 Aug 2024 19:08:03 +0900 Subject: [PATCH] Set X-Forwarded-Host when client using HTTP/2 --- .../proxyserver/ProxyRequestHandler.java | 7 ++-- .../ha/TestGatewayHaSingleBackend.java | 33 +++++++++++++++---- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java index d6bf36174..c10b354ab 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java @@ -45,7 +45,6 @@ import java.util.concurrent.ExecutorService; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.net.HttpHeaders.HOST; import static com.google.common.net.HttpHeaders.VIA; import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR; import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST; @@ -290,9 +289,9 @@ private void addXForwardedHeaders(HttpServletRequest servletRequest, Request.Bui requestBuilder.addHeader(X_FORWARDED_FOR, servletRequest.getRemoteAddr()); requestBuilder.addHeader(X_FORWARDED_PROTO, servletRequest.getScheme()); requestBuilder.addHeader(X_FORWARDED_PORT, String.valueOf(servletRequest.getServerPort())); - String hostHeader = servletRequest.getHeader(HOST); - if (hostHeader != null) { - requestBuilder.addHeader(X_FORWARDED_HOST, hostHeader); + String serverName = servletRequest.getServerName(); + if (serverName != null) { + requestBuilder.addHeader(X_FORWARDED_HOST, serverName); } } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaSingleBackend.java b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaSingleBackend.java index 2a7b7282c..31f2cd598 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaSingleBackend.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaSingleBackend.java @@ -14,26 +14,30 @@ package io.trino.gateway.ha; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import io.trino.gateway.ha.config.ProxyBackendConfiguration; import okhttp3.MediaType; import okhttp3.OkHttpClient; +import okhttp3.Protocol; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.testcontainers.containers.TrinoContainer; +import java.util.stream.Stream; + import static org.assertj.core.api.Assertions.assertThat; import static org.testcontainers.utility.MountableFile.forClasspathResource; @TestInstance(Lifecycle.PER_CLASS) public class TestGatewayHaSingleBackend { - private final OkHttpClient httpClient = new OkHttpClient(); private TrinoContainer trino; int routerPort = 21001 + (int) (Math.random() * 1000); @@ -58,8 +62,18 @@ public void setup() "trino1", "http://localhost:" + backendPort, "externalUrl", true, "adhoc", routerPort); } - @Test - public void testRequestDelivery() + public Stream getOkHttpClient() + { + OkHttpClient.Builder http1Builder = new OkHttpClient.Builder(); + http1Builder.protocols(ImmutableList.of(Protocol.HTTP_1_1)); + OkHttpClient.Builder http2Builder = new OkHttpClient.Builder(); + http2Builder.protocols(ImmutableList.of(Protocol.H2_PRIOR_KNOWLEDGE)); + return Stream.of(http1Builder.build(), http2Builder.build()); + } + + @ParameterizedTest + @MethodSource("getOkHttpClient") + public void testRequestDelivery(OkHttpClient httpClient) throws Exception { RequestBody requestBody = @@ -68,18 +82,23 @@ public void testRequestDelivery() new Request.Builder() .url("http://localhost:" + routerPort + "/v1/statement") .addHeader("X-Trino-User", "test") + .addHeader("Host", "test.host.com") .post(requestBody) .build(); Response response = httpClient.newCall(request).execute(); - assertThat(response.body().string()).contains("nextUri"); + String responseBody = response.body().string(); + assertThat(responseBody).contains("nextUri"); + assertThat(responseBody).contains("test.host.com"); } - @Test - public void testBackendConfiguration() + @ParameterizedTest + @MethodSource("getOkHttpClient") + public void testBackendConfiguration(OkHttpClient httpClient) throws Exception { Request request = new Request.Builder() .url("http://localhost:" + routerPort + "/entity/GATEWAY_BACKEND") + .addHeader("Host", "test.host.com") .method("GET", null) .build(); Response response = httpClient.newCall(request).execute();