Skip to content

Commit

Permalink
Added externalHosts config setting, refactored RSocketClientTranspo…
Browse files Browse the repository at this point in the history
…rt to support connection on multiple addresses
  • Loading branch information
artem-v committed Sep 2, 2023
1 parent db83884 commit 7f57054
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -23,7 +24,7 @@ public class ServiceEndpoint implements Externalizable {
private static final long serialVersionUID = 1L;

private String id;
private Address address;
private List<Address> addresses;
private Set<String> contentTypes;
private Map<String, String> tags;
private Collection<ServiceRegistration> serviceRegistrations;
Expand All @@ -38,7 +39,8 @@ public ServiceEndpoint() {}

private ServiceEndpoint(Builder builder) {
this.id = Objects.requireNonNull(builder.id, "ServiceEndpoint.id is required");
this.address = Objects.requireNonNull(builder.address, "ServiceEndpoint.address is required");
this.addresses =
Objects.requireNonNull(builder.addresses, "ServiceEndpoint.addresses is required");
this.contentTypes = Collections.unmodifiableSet(new HashSet<>(builder.contentTypes));
this.tags = Collections.unmodifiableMap(new HashMap<>(builder.tags));
this.serviceRegistrations =
Expand All @@ -57,8 +59,8 @@ public String id() {
return id;
}

public Address address() {
return address;
public List<Address> addresses() {
return addresses;
}

public Set<String> contentTypes() {
Expand Down Expand Up @@ -88,7 +90,7 @@ public Collection<ServiceReference> serviceReferences() {
public String toString() {
return new StringJoiner(", ", ServiceEndpoint.class.getSimpleName() + "[", "]")
.add("id=" + id)
.add("address=" + address)
.add("addresses=" + addresses)
.add("contentTypes=" + contentTypes)
.add("tags=" + tags)
.add("serviceRegistrations(" + serviceRegistrations.size() + ")")
Expand All @@ -100,8 +102,11 @@ public void writeExternal(ObjectOutput out) throws IOException {
// id
out.writeUTF(id);

// address
out.writeUTF(address.toString());
// addresses
out.writeInt(addresses.size());
for (Address address : addresses) {
out.writeUTF(address.toString());
}

// contentTypes
out.writeInt(contentTypes.size());
Expand All @@ -128,8 +133,12 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
// id
id = in.readUTF();

// address
address = Address.from(in.readUTF());
// addresses
final int addressesSize = in.readInt();
addresses = new ArrayList<>(addressesSize);
for (int i = 0; i < addressesSize; i++) {
addresses.add(Address.from(in.readUTF()));
}

// contentTypes
int contentTypesSize = in.readInt();
Expand Down Expand Up @@ -161,7 +170,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
public static class Builder {

private String id;
private Address address = Address.NULL_ADDRESS;
private List<Address> addresses = new ArrayList<>();
private Set<String> contentTypes = Collections.emptySet();
private Map<String, String> tags = Collections.emptyMap();
private Collection<ServiceRegistration> serviceRegistrations = new ArrayList<>();
Expand All @@ -170,7 +179,7 @@ private Builder() {}

private Builder(ServiceEndpoint other) {
this.id = other.id;
this.address = other.address;
this.addresses = new ArrayList<>(other.addresses);
this.contentTypes = new HashSet<>(other.contentTypes);
this.tags = new HashMap<>(other.tags);
this.serviceRegistrations = new ArrayList<>(other.serviceRegistrations);
Expand All @@ -181,8 +190,12 @@ public Builder id(String id) {
return this;
}

public Builder address(Address address) {
this.address = Objects.requireNonNull(address, "address");
public Builder addresses(Address... addresses) {
return addresses(Arrays.asList(addresses));
}

public Builder addresses(List<Address> addresses) {
this.addresses = new ArrayList<>(Objects.requireNonNull(addresses, "addresses"));
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import io.scalecube.net.Address;
import io.scalecube.services.api.Qualifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -20,7 +22,7 @@ public class ServiceReference {
private final Set<String> contentTypes;
private final Map<String, String> tags;
private final String action;
private final Address address;
private final List<Address> addresses;
private final boolean isSecured;

/**
Expand All @@ -40,7 +42,7 @@ public ServiceReference(
this.tags = mergeTags(serviceMethodDefinition, serviceRegistration, serviceEndpoint);
this.action = serviceMethodDefinition.action();
this.qualifier = Qualifier.asString(namespace, action);
this.address = serviceEndpoint.address();
this.addresses = new ArrayList<>(serviceEndpoint.addresses());
this.isSecured = serviceMethodDefinition.isSecured();
}

Expand Down Expand Up @@ -72,8 +74,8 @@ public String action() {
return action;
}

public Address address() {
return this.address;
public List<Address> addresses() {
return addresses;
}

public boolean isSecured() {
Expand All @@ -95,7 +97,7 @@ private Map<String, String> mergeTags(
public String toString() {
return new StringJoiner(", ", ServiceReference.class.getSimpleName() + "[", "]")
.add("endpointId=" + endpointId)
.add("address=" + address)
.add("addresses=" + addresses)
.add("qualifier=" + qualifier)
.add("contentTypes=" + contentTypes)
.add("tags=" + tags)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public StaticAddressRouter(Address address) {
new ServiceMethodDefinition(UUID.randomUUID().toString()),
new ServiceRegistration(
UUID.randomUUID().toString(), Collections.emptyMap(), Collections.emptyList()),
ServiceEndpoint.builder().id(UUID.randomUUID().toString()).address(address).build());
ServiceEndpoint.builder().id(UUID.randomUUID().toString()).addresses(address).build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,23 @@
import io.scalecube.services.transport.api.ClientChannel;
import java.lang.reflect.Type;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.channel.AbortedException;

public class RSocketClientChannel implements ClientChannel {

private static final Logger LOGGER = LoggerFactory.getLogger(RSocketClientChannel.class);

private final Mono<RSocket> rsocket;
private final Mono<RSocket> promise;
private final ServiceMessageCodec messageCodec;

public RSocketClientChannel(Mono<RSocket> rsocket, ServiceMessageCodec codec) {
this.rsocket = rsocket;
public RSocketClientChannel(Mono<RSocket> promise, ServiceMessageCodec codec) {
this.promise = promise;
this.messageCodec = codec;
}

@Override
public Mono<ServiceMessage> requestResponse(ServiceMessage message, Type responseType) {
return rsocket
return promise
.flatMap(rsocket -> rsocket.requestResponse(toPayload(message)))
.map(this::toMessage)
.map(msg -> ServiceMessageCodec.decodeData(msg, responseType))
Expand All @@ -37,7 +33,7 @@ public Mono<ServiceMessage> requestResponse(ServiceMessage message, Type respons

@Override
public Flux<ServiceMessage> requestStream(ServiceMessage message, Type responseType) {
return rsocket
return promise
.flatMapMany(rsocket -> rsocket.requestStream(toPayload(message)))
.map(this::toMessage)
.map(msg -> ServiceMessageCodec.decodeData(msg, responseType))
Expand All @@ -47,7 +43,7 @@ public Flux<ServiceMessage> requestStream(ServiceMessage message, Type responseT
@Override
public Flux<ServiceMessage> requestChannel(
Publisher<ServiceMessage> publisher, Type responseType) {
return rsocket
return promise
.flatMapMany(rsocket -> rsocket.requestChannel(Flux.from(publisher).map(this::toPayload)))
.map(this::toMessage)
.map(msg -> ServiceMessageCodec.decodeData(msg, responseType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import io.scalecube.utils.MaskUtil;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand All @@ -31,7 +33,7 @@ public class RSocketClientTransport implements ClientTransport {

private static final Logger LOGGER = LoggerFactory.getLogger(RSocketClientTransport.class);

private final ThreadLocal<Map<Address, Mono<RSocket>>> rsockets =
private final ThreadLocal<Map<String, Mono<RSocket>>> connections =
ThreadLocal.withInitial(ConcurrentHashMap::new);

private final CredentialsSupplier credentialsSupplier;
Expand Down Expand Up @@ -64,17 +66,72 @@ public RSocketClientTransport(

@Override
public ClientChannel create(ServiceReference serviceReference) {
final Map<Address, Mono<RSocket>> monoMap = rsockets.get(); // keep reference for threadsafety
final Address address = serviceReference.address();
Mono<RSocket> mono =
monoMap.computeIfAbsent(
address,
final String endpointId = serviceReference.endpointId();
final Map<String, Mono<RSocket>> connections = this.connections.get();

Mono<RSocket> promise =
connections.computeIfAbsent(
endpointId,
key ->
getCredentials(serviceReference)
.flatMap(creds -> connect(key, creds, monoMap))
connect(serviceReference, connections)
.cache()
.doOnError(ex -> monoMap.remove(key)));
return new RSocketClientChannel(mono, new ServiceMessageCodec(headersCodec, dataCodecs));
.doOnError(ex -> connections.remove(key)));

return new RSocketClientChannel(promise, new ServiceMessageCodec(headersCodec, dataCodecs));
}

private Mono<RSocket> connect(
ServiceReference serviceReference, Map<String, Mono<RSocket>> connections) {
return Mono.defer(
() -> {
final String endpointId = serviceReference.endpointId();
final List<Address> addresses = serviceReference.addresses();
final AtomicInteger currentIndex = new AtomicInteger(0);

return Mono.defer(
() -> {
final Address address = addresses.get(currentIndex.get());
return connect(serviceReference, connections, address, endpointId);
})
.doOnError(ex -> currentIndex.incrementAndGet())
.retry(addresses.size() - 1)
.doOnError(
th ->
LOGGER.warn(
"Failed to connect ({}/{}), cause: {}",
endpointId,
addresses,
th.toString()));
});
}

private Mono<RSocket> connect(
ServiceReference serviceReference,
Map<String, Mono<RSocket>> connections,
Address address,
String endpointId) {
return getCredentials(serviceReference)
.flatMap(
creds ->
RSocketConnector.create()
.payloadDecoder(PayloadDecoder.DEFAULT)
.setupPayload(encodeConnectionSetup(new ConnectionSetup(creds)))
.connect(() -> clientTransportFactory.clientTransport(address)))
.doOnSuccess(
rsocket -> {
LOGGER.debug("[{}] Connected successfully", address);
// Setup shutdown hook
rsocket
.onClose()
.doFinally(
s -> {
connections.remove(endpointId);
LOGGER.debug("[{}] Connection closed", address);
})
.doOnError(
th -> LOGGER.warn("[{}] Exception on close: {}", address, th.toString()))
.subscribe();
});
}

private Mono<Map<String, String>> getCredentials(ServiceReference serviceReference) {
Expand Down Expand Up @@ -103,37 +160,6 @@ private Mono<Map<String, String>> getCredentials(ServiceReference serviceReferen
});
}

private Mono<RSocket> connect(
Address address, Map<String, String> creds, Map<Address, Mono<RSocket>> monoMap) {
return RSocketConnector.create()
.payloadDecoder(PayloadDecoder.DEFAULT)
.setupPayload(encodeConnectionSetup(new ConnectionSetup(creds)))
.connect(() -> clientTransportFactory.clientTransport(address))
.doOnSuccess(
rsocket -> {
LOGGER.debug("[rsocket][client][{}] Connected successfully", address);
// setup shutdown hook
rsocket
.onClose()
.doFinally(
s -> {
monoMap.remove(address);
LOGGER.debug("[rsocket][client][{}] Connection closed", address);
})
.doOnError(
th ->
LOGGER.warn(
"[rsocket][client][{}][onClose] Exception occurred: {}",
address,
th.toString()))
.subscribe();
})
.doOnError(
th ->
LOGGER.warn(
"[rsocket][client][{}] Failed to connect, cause: {}", address, th.toString()));
}

private Payload encodeConnectionSetup(ConnectionSetup connectionSetup) {
ByteBuf byteBuf = ByteBufAllocator.DEFAULT.buffer();
try {
Expand Down
Loading

0 comments on commit 7f57054

Please sign in to comment.