Skip to content

Commit

Permalink
Merge pull request #322 from dcherednik/ip_discovery
Browse files Browse the repository at this point in the history
Support for using ip address in discovery response.
  • Loading branch information
alex268 authored Sep 27, 2024
2 parents 14c04e1 + ce9981a commit a133f3d
Show file tree
Hide file tree
Showing 14 changed files with 71 additions and 36 deletions.
19 changes: 18 additions & 1 deletion core/src/main/java/tech/ydb/core/impl/YdbDiscovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tech.ydb.core.operation.OperationBinder;
import tech.ydb.core.utils.FutureTools;
import tech.ydb.proto.discovery.DiscoveryProtos;
import tech.ydb.proto.discovery.DiscoveryProtos.EndpointInfo;
import tech.ydb.proto.discovery.v1.DiscoveryServiceGrpc;

/**
Expand Down Expand Up @@ -185,6 +186,21 @@ private void handleOk(String selfLocation, List<EndpointRecord> endpoints) {
}
}

private static String createAddress(EndpointInfo e) {
String addr;
if (e.getIpV6Count() > 0 && e.getIpV6(0) != null && !e.getIpV6(0).isEmpty()) {
addr = e.getIpV6(0);
} else if (e.getIpV4Count() > 0 && e.getIpV4(0) != null && !e.getIpV4(0).isEmpty()) {
addr = e.getIpV4(0);
} else {
addr = e.getAddress();
}

logger.debug("address {} will be used to connect to node {}", addr, e.getAddress());

return addr;
}

private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> response, Throwable th) {
if (th != null) {
Throwable cause = FutureTools.unwrapCompletionException(th);
Expand All @@ -202,7 +218,8 @@ private void handleDiscoveryResult(Result<DiscoveryProtos.ListEndpointsResult> r
}

List<EndpointRecord> records = result.getEndpointsList().stream()
.map(e -> new EndpointRecord(e.getAddress(), e.getPort(), e.getNodeId(), e.getLocation()))
.map(e -> new EndpointRecord(createAddress(e), e.getPort(), e.getNodeId(), e.getLocation(),
e.getSslTargetNameOverride()))
.collect(Collectors.toList());

logger.debug("successfully received ListEndpoints result with {} endpoints", records.size());
Expand Down
17 changes: 14 additions & 3 deletions core/src/main/java/tech/ydb/core/impl/pool/EndpointRecord.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,35 @@ public class EndpointRecord {
private final String host;
private final String hostAndPort;
private final String locationDC;
private final String authority;
private final int port;
private final int nodeId;

public EndpointRecord(String host, int port, int nodeId, String locationDC) {
public EndpointRecord(String host, int port, int nodeId, String locationDC, String authority) {
this.host = Objects.requireNonNull(host);
this.port = port;
this.hostAndPort = host + ":" + port;
this.nodeId = nodeId;
this.locationDC = locationDC;
if (authority != null && !authority.isEmpty()) {
this.authority = authority;
} else {
this.authority = null;
}
}

public EndpointRecord(String host, int port) {
this(host, port, 0, null);
this(host, port, 0, null, null);
}

public String getHost() {
return host;
}

public String getAuthority() {
return authority;
}

public int getPort() {
return port;
}
Expand All @@ -46,6 +56,7 @@ public String getLocation() {

@Override
public String toString() {
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId + ", location=" + locationDC + "}";
return "Endpoint{host=" + host + ", port=" + port + ", node=" + nodeId +
", location=" + locationDC + ", overrideAuthority=" + authority + "}";
}
}
3 changes: 2 additions & 1 deletion core/src/main/java/tech/ydb/core/impl/pool/GrpcChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public GrpcChannel(EndpointRecord endpoint, ManagedChannelFactory factory) {
try {
logger.debug("Creating grpc channel with {}", endpoint);
this.endpoint = endpoint;
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort());
this.channel = factory.newManagedChannel(endpoint.getHost(), endpoint.getPort(),
endpoint.getAuthority());
this.connectTimeoutMs = factory.getConnectTimeoutMs();
this.readyWatcher = new ReadyWatcher();
this.readyWatcher.checkState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ interface Builder {
ManagedChannelFactory buildFactory(GrpcTransportBuilder builder);
}

ManagedChannel newManagedChannel(String host, int port);
ManagedChannel newManagedChannel(String host, int port, String authority);

long getConnectTimeoutMs();
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {

@SuppressWarnings("deprecation")
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
NettyChannelBuilder channelBuilder = NettyChannelBuilder
.forAddress(host, port);

if (useTLS) {
channelBuilder
.negotiationType(NegotiationType.TLS)
.sslContext(createSslContext());
if (sslHostOverride != null) {
channelBuilder.overrideAuthority(sslHostOverride);
}
} else {
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public long getConnectTimeoutMs() {

@SuppressWarnings("deprecation")
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String sslHostOverride) {
NettyChannelBuilder channelBuilder = NettyChannelBuilder
.forAddress(host, port);

if (useTLS) {
channelBuilder
.negotiationType(NegotiationType.TLS)
.sslContext(createSslContext());
if (sslHostOverride != null) {
channelBuilder.overrideAuthority(sslHostOverride);
}
} else {
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void setUp() throws InterruptedException {
Mockito.when(channel.shutdownNow()).thenReturn(channel);
Mockito.when(channel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);

Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt())).thenReturn(channel);
Mockito.when(channelFactory.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull())).thenReturn(channel);
}

private <T extends Throwable> T checkFutureException(CompletableFuture<Boolean> f, String message, Class<T> clazz) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public void setUp() throws InterruptedException {
Mockito.when(transportChannel.shutdownNow()).thenReturn(transportChannel);
Mockito.when(transportChannel.awaitTermination(Mockito.anyLong(), Mockito.any())).thenReturn(true);

Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136)))
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("mocked"), Mockito.eq(2136), Mockito.isNull()))
.thenReturn(discoveryChannel);
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136)))
Mockito.when(channelFactory.newManagedChannel(Mockito.eq("node"), Mockito.eq(2136), Mockito.isNull()))
.thenReturn(transportChannel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void defaultParams() {
channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertEquals(30_000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -100,7 +100,7 @@ public void defaultSslFactory() {
channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertEquals(60000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -124,7 +124,7 @@ public void customChannelInitializer() {

channelStaticMock.verify(FOR_ADDRESS, times(0));

Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

channelStaticMock.verify(FOR_ADDRESS, times(1));

Expand All @@ -150,7 +150,7 @@ public void customSslFactory() throws CertificateException, IOException {
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);

Assert.assertEquals(4000l, factory.getConnectTimeoutMs());
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

} finally {
selfSignedCert.delete();
Expand All @@ -176,7 +176,7 @@ public void invalidSslCert() {
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);

RuntimeException ex = Assert.assertThrows(RuntimeException.class,
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT));
() -> factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));

Assert.assertEquals("cannot create ssl context", ex.getMessage());
Assert.assertNotNull(ex.getCause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ public void nodePessimizationTest() {
check(pool.getEndpoint(2)).hostname("n2.ydb.tech").nodeID(2).port(12342);

// Pessimize unknown nodes - nothing is changed
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12341, 2, null, null));
pool.pessimizeEndpoint(new EndpointRecord("n2.ydb.tech", 12342, 2, null, null));
pool.pessimizeEndpoint(null);
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(4);

Expand Down Expand Up @@ -553,6 +553,6 @@ private static List<EndpointRecord> list(EndpointRecord... records) {
}

private static EndpointRecord endpoint(int nodeID, String hostname, int port, String location) {
return new EndpointRecord(hostname, port, nodeID, location);
return new EndpointRecord(hostname, port, nodeID, location, null);
}
}
20 changes: 10 additions & 10 deletions core/src/test/java/tech/ydb/core/impl/pool/GrpcChannelPoolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class GrpcChannelPoolTest {
@Before
public void setUp() {
Mockito.when(factoryMock.getConnectTimeoutMs()).thenReturn(500l); // timeout for ready watcher
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.then((args) -> ManagedChannelMock.good());
}

Expand All @@ -34,8 +34,8 @@ public void tearDown() throws Exception {

@Test
public void simpleTest() {
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down Expand Up @@ -66,9 +66,9 @@ public void simpleTest() {

@Test
public void removeChannels() {
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down Expand Up @@ -121,13 +121,13 @@ public void removeChannels() {

@Test
public void badShutdownTest() {
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt())).thenReturn(
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull())).thenReturn(
ManagedChannelMock.good(), ManagedChannelMock.good(),
ManagedChannelMock.wrongShutdown(), ManagedChannelMock.wrongShutdown());

EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null);
EndpointRecord e1 = new EndpointRecord("host1", 1234, 10, null, null);
EndpointRecord e2 = new EndpointRecord("host1", 1235, 11, null, null);
EndpointRecord e3 = new EndpointRecord("host1", 1236, 12, null, null);

GrpcChannelPool pool = new GrpcChannelPool(factoryMock, scheduler);
Assert.assertEquals(0, pool.getChannels().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void setUp() {

@Test
public void goodChannels() {
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(ManagedChannelMock.good(), ManagedChannelMock.good());

EndpointRecord endpoint = new EndpointRecord("host1", 1234);
Expand All @@ -52,7 +52,7 @@ public void slowChannels() {
ConnectivityState.READY,
};

Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states));

Expand All @@ -74,7 +74,7 @@ public void badChannels() {
ConnectivityState.SHUTDOWN,
};

Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt()))
Mockito.when(factoryMock.newManagedChannel(Mockito.any(), Mockito.anyInt(), Mockito.isNull()))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states))
.thenReturn(new ManagedChannelMock(ConnectivityState.IDLE).nextStates(states));

Expand All @@ -84,7 +84,7 @@ public void badChannels() {
Assert.assertEquals(endpoint, channel.getEndpoint());

RuntimeException ex1 = Assert.assertThrows(RuntimeException.class, channel::getReadyChannel);
Assert.assertEquals("Channel Endpoint{host=host1, port=1234, node=0, location=null} connecting problem",
Assert.assertEquals("Channel Endpoint{host=host1, port=1234, node=0, location=null, overrideAuthority=null} connecting problem",
ex1.getMessage());

channel.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE

public static ManagedChannelFactory.Builder MOCKED = (GrpcTransportBuilder builder) -> new ManagedChannelFactory() {
@Override
public ManagedChannel newManagedChannel(String host, int port) {
public ManagedChannel newManagedChannel(String host, int port, String authority) {
return good();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void fixedLocalDcTest() {

@Test
public void detectLocalDCfallbackTest() {
List<EndpointRecord> single = Collections.singletonList(new EndpointRecord("localhost", 8080, 0, "DC1"));
List<EndpointRecord> single = Collections.singletonList(new EndpointRecord("localhost", 8080, 0, "DC1", null));
PriorityPicker ignoreSelftLocation = PriorityPicker.from(BalancingSettings.detectLocalDs(), "DC1", single);

Assert.assertEquals(0, ignoreSelftLocation.getEndpointPriority("DC1"));
Expand All @@ -73,7 +73,7 @@ public void detectLocalDCTest() {
final int port = serverSocket.getLocalPort();

List<EndpointRecord> records = Arrays.asList("DC1", "DC1", "DC2", "DC2", "DC2", "DC3")
.stream().map(location -> new EndpointRecord("localhost", port, 1, location))
.stream().map(location -> new EndpointRecord("localhost", port, 1, location, null))
.collect(Collectors.toList());

String localDC = PriorityPicker.detectLocalDC(records, testTicker);
Expand Down

0 comments on commit a133f3d

Please sign in to comment.