diff --git a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/GatewayClientSettings.java b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/GatewayClientSettings.java index 230450a8..c02acbc5 100644 --- a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/GatewayClientSettings.java +++ b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/GatewayClientSettings.java @@ -4,6 +4,9 @@ import io.scalecube.services.exceptions.DefaultErrorMapper; import io.scalecube.services.exceptions.ServiceClientErrorMapper; import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import reactor.netty.tcp.SslProvider; public class GatewayClientSettings { @@ -20,6 +23,7 @@ public class GatewayClientSettings { private final ServiceClientErrorMapper errorMapper; private final Duration keepAliveInterval; private final boolean wiretap; + private final Map headers; private GatewayClientSettings(Builder builder) { this.host = builder.host; @@ -30,6 +34,7 @@ private GatewayClientSettings(Builder builder) { this.errorMapper = builder.errorMapper; this.keepAliveInterval = builder.keepAliveInterval; this.wiretap = builder.wiretap; + this.headers = builder.headers; } public String host() { @@ -64,6 +69,10 @@ public boolean wiretap() { return this.wiretap; } + public Map headers() { + return headers; + } + public static Builder builder() { return new Builder(); } @@ -96,9 +105,9 @@ public static class Builder { private ServiceClientErrorMapper errorMapper = DefaultErrorMapper.INSTANCE; private Duration keepAliveInterval = DEFAULT_KEEPALIVE_INTERVAL; private boolean wiretap = false; + private Map headers = Collections.emptyMap(); - private Builder() { - } + private Builder() {} private Builder(GatewayClientSettings originalSettings) { this.host = originalSettings.host; @@ -109,6 +118,7 @@ private Builder(GatewayClientSettings originalSettings) { this.errorMapper = originalSettings.errorMapper; this.keepAliveInterval = originalSettings.keepAliveInterval; this.wiretap = originalSettings.wiretap; + this.headers = Collections.unmodifiableMap(new HashMap<>(originalSettings.headers)); } public Builder host(String host) { @@ -191,6 +201,11 @@ public Builder errorMapper(ServiceClientErrorMapper errorMapper) { return this; } + public Builder headers(Map headers) { + this.headers = Collections.unmodifiableMap(new HashMap<>(headers)); + return this; + } + public GatewayClientSettings build() { return new GatewayClientSettings(this); } diff --git a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/http/HttpGatewayClient.java b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/http/HttpGatewayClient.java index 831048d9..0fac2bbd 100644 --- a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/http/HttpGatewayClient.java +++ b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/http/HttpGatewayClient.java @@ -41,6 +41,7 @@ public HttpGatewayClient(GatewayClientSettings settings, GatewayClientCodec settings.headers().forEach(headers::add)) .followRedirect(settings.followRedirect()) .tcpConfiguration( tcpClient -> { diff --git a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/rsocket/RSocketGatewayClient.java b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/rsocket/RSocketGatewayClient.java index 107574d0..98e3a63a 100644 --- a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/rsocket/RSocketGatewayClient.java +++ b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/rsocket/RSocketGatewayClient.java @@ -5,6 +5,7 @@ import io.rsocket.core.RSocketConnector; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.EmptyPayload; import io.scalecube.services.api.ServiceMessage; import io.scalecube.services.exceptions.ConnectionClosedException; import io.scalecube.services.gateway.transport.GatewayClient; @@ -135,8 +136,14 @@ private Mono getOrConnect0(Mono prev) { return prev; } + Payload setupPayload = EmptyPayload.INSTANCE; + if (!settings.headers().isEmpty()) { + setupPayload = codec.encode(ServiceMessage.builder().headers(settings.headers()).build()); + } + return RSocketConnector.create() .payloadDecoder(PayloadDecoder.DEFAULT) + .setupPayload(setupPayload) .metadataMimeType(settings.contentType()) .connect(createRSocketTransport(settings)) .doOnSuccess( diff --git a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/websocket/WebsocketGatewayClient.java b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/websocket/WebsocketGatewayClient.java index 905792fc..432657de 100644 --- a/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/websocket/WebsocketGatewayClient.java +++ b/services-gateway-client-transport/src/main/java/io/scalecube/services/gateway/transport/websocket/WebsocketGatewayClient.java @@ -55,6 +55,7 @@ public WebsocketGatewayClient(GatewayClientSettings settings, GatewayClientCodec httpClient = HttpClient.newConnection() + .headers(headers -> settings.headers().forEach(headers::add)) .followRedirect(settings.followRedirect()) .tcpConfiguration( tcpClient -> { diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySession.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySession.java index 871c66f4..aa3b3d95 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySession.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySession.java @@ -1,6 +1,5 @@ package io.scalecube.services.gateway; -import java.util.List; import java.util.Map; public interface GatewaySession { @@ -15,7 +14,7 @@ public interface GatewaySession { /** * Returns headers associated with session. * - * @return heades map + * @return headers map */ - Map> headers(); + Map headers(); } diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySessionHandler.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySessionHandler.java index 71ac8e33..650894d5 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySessionHandler.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/GatewaySessionHandler.java @@ -2,7 +2,6 @@ import io.netty.buffer.ByteBuf; import io.scalecube.services.api.ServiceMessage; -import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -84,7 +83,7 @@ default void onSessionError(GatewaySession session, Throwable throwable) { * @param headers connection/session headers * @return mono result */ - default Mono onConnectionOpen(long sessionId, Map> headers) { + default Mono onConnectionOpen(long sessionId, Map headers) { return Mono.fromRunnable( () -> LOGGER.debug( diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewayAcceptor.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewayAcceptor.java index f49d246f..a7732f27 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewayAcceptor.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewayAcceptor.java @@ -7,6 +7,7 @@ import io.scalecube.services.gateway.GatewaySessionHandler; import io.scalecube.services.gateway.ServiceMessageCodec; import io.scalecube.services.transport.api.HeadersCodec; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -41,6 +42,7 @@ public Mono accept(ConnectionSetupPayload setup, RSocket rsocket) { new RSocketGatewaySession( serviceCall, messageCodec, + headers(messageCodec, setup), (session, req) -> sessionHandler.mapMessage(session, req, Context.empty())); sessionHandler.onSessionOpen(gatewaySession); rsocket @@ -54,4 +56,11 @@ public Mono accept(ConnectionSetupPayload setup, RSocket rsocket) { return Mono.just(gatewaySession); } + + private Map headers( + ServiceMessageCodec messageCodec, ConnectionSetupPayload setup) { + return messageCodec + .decode(setup.sliceData().retain(), setup.sliceMetadata().retain()) + .headers(); + } } diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewaySession.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewaySession.java index 9bdf50c5..85e713b5 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewaySession.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/rsocket/RSocketGatewaySession.java @@ -10,7 +10,7 @@ import io.scalecube.services.gateway.ReferenceCountUtil; import io.scalecube.services.gateway.ServiceMessageCodec; import java.util.Collections; -import java.util.List; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; @@ -30,6 +30,7 @@ public final class RSocketGatewaySession extends AbstractRSocket implements Gate private final ServiceMessageCodec messageCodec; private final long sessionId; private final BiFunction messageMapper; + private final Map headers; /** * Constructor for gateway rsocket. @@ -40,11 +41,13 @@ public final class RSocketGatewaySession extends AbstractRSocket implements Gate public RSocketGatewaySession( ServiceCall serviceCall, ServiceMessageCodec messageCodec, + Map headers, BiFunction messageMapper) { this.serviceCall = serviceCall; this.messageCodec = messageCodec; this.messageMapper = messageMapper; this.sessionId = SESSION_ID_GENERATOR.incrementAndGet(); + this.headers = Collections.unmodifiableMap(new HashMap<>(headers)); } @Override @@ -53,8 +56,8 @@ public long sessionId() { } @Override - public Map> headers() { - return Collections.emptyMap(); + public Map headers() { + return headers; } @Override diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewayAcceptor.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewayAcceptor.java index 32f0f700..83da2c0a 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewayAcceptor.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewayAcceptor.java @@ -21,14 +21,14 @@ import io.scalecube.services.exceptions.UnauthorizedException; import io.scalecube.services.gateway.GatewaySessionHandler; import io.scalecube.services.gateway.ReferenceCountUtil; -import java.util.HashMap; -import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; +import java.util.stream.Collectors; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -63,7 +63,7 @@ public WebsocketGatewayAcceptor(ServiceCall serviceCall, GatewaySessionHandler g @Override public Publisher apply(HttpServerRequest httpRequest, HttpServerResponse httpResponse) { - final Map> headers = computeHeaders(httpRequest.requestHeaders()); + final Map headers = computeHeaders(httpRequest.requestHeaders()); final long sessionId = SESSION_ID_GENERATOR.incrementAndGet(); return gatewayHandler @@ -85,12 +85,9 @@ public Publisher apply(HttpServerRequest httpRequest, HttpServerResponse h .onErrorResume(throwable -> Mono.empty()); } - private static Map> computeHeaders(HttpHeaders httpHeaders) { - Map> headers = new HashMap<>(); - for (String name : httpHeaders.names()) { - headers.put(name, httpHeaders.getAll(name)); - } - return headers; + private static Map computeHeaders(HttpHeaders httpHeaders) { + // exception will be thrown on duplicate + return httpHeaders.entries().stream().collect(Collectors.toMap(Entry::getKey, Entry::getValue)); } private static int toStatusCode(Throwable throwable) { diff --git a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewaySession.java b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewaySession.java index ff04690c..b21f0994 100644 --- a/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewaySession.java +++ b/services-gateway-netty/src/main/java/io/scalecube/services/gateway/ws/WebsocketGatewaySession.java @@ -9,7 +9,6 @@ import io.scalecube.services.gateway.GatewaySessionHandler; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.jctools.maps.NonBlockingHashMapLong; import org.slf4j.Logger; @@ -33,7 +32,7 @@ public final class WebsocketGatewaySession implements GatewaySession { private final WebsocketServiceMessageCodec codec; private final long sessionId; - private final Map> headers; + private final Map headers; /** * Create a new websocket session with given handshake, inbound and outbound channels. @@ -48,7 +47,7 @@ public final class WebsocketGatewaySession implements GatewaySession { public WebsocketGatewaySession( long sessionId, WebsocketServiceMessageCodec codec, - Map> headers, + Map headers, WebsocketInbound inbound, WebsocketOutbound outbound, GatewaySessionHandler gatewayHandler) { @@ -68,7 +67,7 @@ public long sessionId() { } @Override - public Map> headers() { + public Map headers() { return headers; } diff --git a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/TestGatewaySessionHandler.java b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/TestGatewaySessionHandler.java index 84a30ac7..e4562add 100644 --- a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/TestGatewaySessionHandler.java +++ b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/TestGatewaySessionHandler.java @@ -2,6 +2,7 @@ import io.scalecube.services.api.ServiceMessage; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import reactor.util.context.Context; public class TestGatewaySessionHandler implements GatewaySessionHandler { @@ -9,6 +10,7 @@ public class TestGatewaySessionHandler implements GatewaySessionHandler { public final CountDownLatch msgLatch = new CountDownLatch(1); public final CountDownLatch connLatch = new CountDownLatch(1); public final CountDownLatch disconnLatch = new CountDownLatch(1); + private final AtomicReference lastSession = new AtomicReference<>(); @Override public ServiceMessage mapMessage(GatewaySession s, ServiceMessage req, Context context) { @@ -19,10 +21,15 @@ public ServiceMessage mapMessage(GatewaySession s, ServiceMessage req, Context c @Override public void onSessionOpen(GatewaySession s) { connLatch.countDown(); + lastSession.set(s); } @Override public void onSessionClose(GatewaySession s) { disconnLatch.countDown(); } + + public GatewaySession lastSession() { + return lastSession.get(); + } } diff --git a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/rsocket/RsocketClientConnectionTest.java b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/rsocket/RsocketClientConnectionTest.java index bd2a00db..ab619973 100644 --- a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/rsocket/RsocketClientConnectionTest.java +++ b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/rsocket/RsocketClientConnectionTest.java @@ -23,6 +23,8 @@ import io.scalecube.services.transport.rsocket.RSocketServiceTransport; import java.io.IOException; import java.time.Duration; +import java.util.Map; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.AfterEach; @@ -149,4 +151,29 @@ public void testHandlerEvents() throws InterruptedException { sessionEventHandler.disconnLatch.await(3, TimeUnit.SECONDS); Assertions.assertEquals(0, sessionEventHandler.disconnLatch.getCount()); } + + @Test + void testClientSettingsHeaders() { + String headerKey = "secret-token"; + String headerValue = UUID.randomUUID().toString(); + client = + new RSocketGatewayClient( + GatewayClientSettings.builder() + .headers(Map.of(headerKey, headerValue)) + .address(gatewayAddress) + .build(), + CLIENT_CODEC); + + TestService service = + new ServiceCall() + .transport(new GatewayClientTransport(client)) + .router(new StaticAddressRouter(gatewayAddress)) + .api(TestService.class); + + StepVerifier.create( + service.one("one").then(Mono.fromCallable(() -> sessionEventHandler.lastSession()))) + .assertNext(session -> assertEquals(headerValue, session.headers().get(headerKey))) + .expectComplete() + .verify(TIMEOUT); + } } diff --git a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/websocket/WebsocketClientConnectionTest.java b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/websocket/WebsocketClientConnectionTest.java index 4f73e450..00425b0b 100644 --- a/services-gateway-tests/src/test/java/io/scalecube/services/gateway/websocket/WebsocketClientConnectionTest.java +++ b/services-gateway-tests/src/test/java/io/scalecube/services/gateway/websocket/WebsocketClientConnectionTest.java @@ -31,6 +31,8 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.time.Duration; +import java.util.Map; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -200,4 +202,28 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception assertEquals(0, keepaliveLatch.getCount()); } + + @Test + void testClientSettingsHeaders() { + String headerKey = "secret-token"; + String headerValue = UUID.randomUUID().toString(); + client = + new WebsocketGatewayClient( + GatewayClientSettings.builder() + .address(gatewayAddress) + .headers(Map.of(headerKey, headerValue)) + .build(), + CLIENT_CODEC); + TestService service = + new ServiceCall() + .transport(new GatewayClientTransport(client)) + .router(new StaticAddressRouter(gatewayAddress)) + .api(TestService.class); + + StepVerifier.create( + service.one("one").then(Mono.fromCallable(() -> sessionEventHandler.lastSession()))) + .assertNext(session -> assertEquals(headerValue, session.headers().get(headerKey))) + .expectComplete() + .verify(TIMEOUT); + } }