Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed routing map provider NPE #42874

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.rx.TestSuiteBase;
import com.azure.cosmos.rx.proxy.HttpProxyServer;
import io.netty.channel.ChannelOption;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testng.annotations.AfterClass;
Expand All @@ -34,6 +35,7 @@

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -126,7 +128,7 @@ public void operationsList(CosmosClient cosmosClient) throws Exception {

ClientTelemetry clientTelemetry = cosmosClient.asyncClient().getContextClient().getClientTelemetry();
setClientTelemetrySchedulingInSec(clientTelemetry, 5);
clientTelemetry.init();
clientTelemetry.init().subscribe();

InternalObjectNode internalObjectNode = getInternalObjectNode();
cosmosContainer.createItem(internalObjectNode); //create operation
Expand Down Expand Up @@ -190,7 +192,7 @@ public void operationsListWithNoTelemetry() throws Exception {
"clientTelemetrySchedulingSec");
backgroundRefreshLocationTimeIntervalInMSField.setAccessible(true);
backgroundRefreshLocationTimeIntervalInMSField.setInt(clientTelemetry, 5);
clientTelemetry.init();
clientTelemetry.init().subscribe();

InternalObjectNode internalObjectNode = getInternalObjectNode();
cosmosContainer.createItem(internalObjectNode); // create operation
Expand Down Expand Up @@ -251,10 +253,60 @@ public void httpClientTests(CosmosClient cosmosClient) throws Exception {
AtomicReference<AzureVMMetadata> vmMetadata = ReflectionUtils.getAzureVMMetadata(clientTelemetry);
vmMetadata.set(null);

clientTelemetry.init();
clientTelemetry.init().subscribe();
assertThat(clientTelemetryMetadataHttpClientWrapper.capturedRequests.size()).isEqualTo(1);
}


@Test(groups = {"emulator"}, dataProvider = "clients", timeOut = TIMEOUT)
public void shouldDisableIMDSAccess(CosmosClient cosmosClient) throws Exception {
// Test using different http client for client telemetry requests and metaRequests

System.setProperty("COSMOS.DISABLE_IMDS_ACCESS", "true");

ClientTelemetry clientTelemetry = cosmosClient.asyncClient().getContextClient().getClientTelemetry();
HttpClient clientTelemetryHttpClient = ReflectionUtils.getClientTelemetryMetadataHttpClient(clientTelemetry);
HttpClient clientTelemetryMetadataHttpClient = ReflectionUtils.getClientTelemetryHttpClint(clientTelemetry);

assertThat(clientTelemetryHttpClient).isNotSameAs(clientTelemetryMetadataHttpClient);

// Test metadataHttpClient is used for IMDS requests
HttpClientUnderTestWrapper clientTelemetryMetadataHttpClientWrapper = new HttpClientUnderTestWrapper(clientTelemetryHttpClient);
ReflectionUtils.setClientTelemetryMetadataHttpClient(clientTelemetry, clientTelemetryMetadataHttpClientWrapper.getSpyHttpClient());
AtomicReference<AzureVMMetadata> vmMetadata = ReflectionUtils.getAzureVMMetadata(clientTelemetry);
vmMetadata.set(null);

clientTelemetry.init().subscribe();
// Call should not go through loading azure VM metadata
assertThat(clientTelemetryMetadataHttpClientWrapper.capturedRequests.size()).isEqualTo(0);

System.setProperty("COSMOS.DISABLE_IMDS_ACCESS", "false");// setting it back for other tests
}


@Test(groups = {"emulator"}, dataProvider = "clients", timeOut = TIMEOUT)
public void httpClientsConfigurationTests(CosmosClient cosmosClient) throws Exception {
// Test using different http client for client telemetry requests and metaRequests
ClientTelemetry clientTelemetry = cosmosClient.asyncClient().getContextClient().getClientTelemetry();
HttpClient clientTelemetryHttpClient = ReflectionUtils.getClientTelemetryMetadataHttpClient(clientTelemetry);
HttpClient clientTelemetryMetadataHttpClient = ReflectionUtils.getClientTelemetryHttpClint(clientTelemetry);

assertThat(clientTelemetryHttpClient).isNotSameAs(clientTelemetryMetadataHttpClient);

reactor.netty.http.client.HttpClient reactorHttpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, clientTelemetryMetadataHttpClient,
"httpClient");

Duration responseTimeout = reactorHttpClient.configuration().responseTimeout();
int maxConnections = reactorHttpClient.configuration().connectionProvider().maxConnections();
Integer connectionAcquireTimeout = (Integer) reactorHttpClient.configuration().options().get(ChannelOption.CONNECT_TIMEOUT_MILLIS);

assertThat(responseTimeout).isEqualTo(ClientTelemetry.IMDS_DEFAULT_NETWORK_REQUEST_TIMEOUT);
assertThat(maxConnections).isEqualTo(ClientTelemetry.IMDS_DEFAULT_MAX_CONNECTION_POOL_SIZE);
assertThat(connectionAcquireTimeout).isEqualTo((int) ClientTelemetry.IMDS_DEFAULT_CONNECTION_ACQUIRE_TIMEOUT.toMillis());
}


@Test(groups = {"unit"})
public void clientTelemetryScheduling() {
assertThat(Configs.getClientTelemetrySchedulingInSec()).isEqualTo(600);
Expand Down Expand Up @@ -304,7 +356,7 @@ public void clientTelemetryWithStageJunoEndpoint(boolean useProxy) throws Interr
cosmosClient.getDatabase(databaseId).getContainer(containerId);
ClientTelemetry clientTelemetry = cosmosClient.asyncClient().getContextClient().getClientTelemetry();
setClientTelemetrySchedulingInSec(clientTelemetry, 5);
clientTelemetry.init();
clientTelemetry.init().subscribe();

// If this test need to run on local machine please add below env property,
// in test env we add the env property with cosmos-client-telemetry-endpoint variable in tests.yml,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public void rntbd() throws Exception {
assertThat(objectNode.get("numberOfClients").asInt()).isEqualTo(2);
assertThat(objectNode.get("consistencyCfg").asText()).isEqualTo("(consistency: null, mm: false, prgns: [null])");
assertThat(objectNode.get("connCfg").get("rntbd").asText()).isEqualTo("(cto:PT5S, nrto:PT5S, icto:PT0S, ieto:PT1H, mcpe:130, mrpc:30, cer:true)");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:null, nrto:null, icto:null, p:false)");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:1000, nrto:PT1M, icto:PT1M, cto:PT45S, p:false)");
assertThat(objectNode.get("connCfg").get("other").asText()).isEqualTo("(ed: false, cs: false, rv: true)");
}

Expand Down Expand Up @@ -237,7 +237,7 @@ public void gw() throws Exception {
assertThat(objectNode.get("numberOfClients").asInt()).isEqualTo(2);
assertThat(objectNode.get("consistencyCfg").asText()).isEqualTo("(consistency: null, mm: false, prgns: [null])");
assertThat(objectNode.get("connCfg").get("rntbd").asText()).isEqualTo("null");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:500, nrto:PT18S, icto:PT17S, p:false)");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:500, nrto:PT18S, icto:PT17S, cto:PT45S, p:false)");
assertThat(objectNode.get("connCfg").get("other").asText()).isEqualTo("(ed: false, cs: false, rv: true)");
}

Expand Down Expand Up @@ -309,7 +309,7 @@ public void full(
assertThat(objectNode.get("numberOfClients").asInt()).isEqualTo(2);
assertThat(objectNode.get("consistencyCfg").asText()).isEqualTo("(consistency: null, mm: false, prgns: [westus1,westus2])");
assertThat(objectNode.get("connCfg").get("rntbd").asText()).isEqualTo("null");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:500, nrto:PT18S, icto:PT17S, p:false)");
assertThat(objectNode.get("connCfg").get("gw").asText()).isEqualTo("(cps:500, nrto:PT18S, icto:PT17S, cto:PT45S, p:false)");
assertThat(objectNode.get("connCfg").get("other").asText()).isEqualTo("(ed: true, cs: true, rv: false)");
assertThat(objectNode.get("excrgns").asText()).isEqualTo("[westus2]");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.azure.core.credential.AzureKeyCredential;
import com.azure.cosmos.ConnectionMode;
import com.azure.cosmos.ConsistencyLevel;
import com.azure.cosmos.implementation.circuitBreaker.GlobalPartitionEndpointManagerForCircuitBreaker;
import com.azure.cosmos.implementation.directconnectivity.Protocol;
import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils;
import com.azure.cosmos.implementation.http.HttpClient;
Expand Down Expand Up @@ -199,15 +198,11 @@ public static class ClientUnderTest extends SpyBaseClass<HttpRequest> {

private Mono<HttpResponse> captureHttpRequest(InvocationOnMock invocationOnMock) {
HttpRequest httpRequest = invocationOnMock.getArgument(0, HttpRequest.class);
Duration responseTimeout = Duration.ofSeconds(Configs.getHttpResponseTimeoutInSeconds());
if (invocationOnMock.getArguments().length == 2) {
responseTimeout = invocationOnMock.getArgument(1, Duration.class);
}
CompletableFuture<HttpResponse> f = new CompletableFuture<>();
this.requestsResponsePairs.add(Pair.of(httpRequest, f));

return origHttpClient
.send(httpRequest, responseTimeout)
.send(httpRequest)
.doOnNext(httpResponse -> f.complete(httpResponse.buffer()))
.doOnError(f::completeExceptionally);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ private HttpClientMockWrapper(long responseAfterMillis, final HttpResponse httpR
return Mono.delay(Duration.ofMillis(responseAfterMillis)).flatMap(t -> httpResponseOrException(httpResponse, e));
}
}).when(httpClient).send(Mockito.any(HttpRequest.class), Mockito.any(Duration.class));

Mockito.doAnswer(invocationOnMock -> {
HttpRequest httpRequest = invocationOnMock.getArgument(0, HttpRequest.class);
requests.add(httpRequest);
if (responseAfterMillis <= 0) {
return httpResponseOrException(httpResponse, e);
} else {
return Mono.delay(Duration.ofMillis(responseAfterMillis)).flatMap(t -> httpResponseOrException(httpResponse, e));
}
}).when(httpClient).send(Mockito.any(HttpRequest.class));
}

public HttpClientMockWrapper(HttpClientBehaviourBuilder builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public HttpTransportClientUnderTest(ConnectionPolicy connectionPolicy, UserAgent
}

@Override
HttpClient createHttpClient(ConnectionPolicy connectionPolicy) {
HttpClient createHttpClient(Configs configs, ConnectionPolicy connectionPolicy) {
return httpClient;
}
}
Expand Down Expand Up @@ -141,6 +141,7 @@ public void validateDefaultHeaders() {
RxDocumentServiceRequest request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(),
OperationType.Create, "dbs/db/colls/col", ResourceType.Document);
request.setContentBytes(new byte[0]);
request.setResponseTimeout(connectionPolicy.getHttpNetworkRequestTimeout());

transportClient.invokeResourceOperationAsync(Uri.create(physicalAddress), request).block();

Expand Down Expand Up @@ -460,6 +461,7 @@ public void failuresWithHttpStatusCodes(HttpClientMockWrapper.HttpClientBehaviou
httpClientMockWrapper.getClient());
RxDocumentServiceRequest request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(),
OperationType.Create, "dbs/db/colls/col", ResourceType.Document);
request.setResponseTimeout(connectionPolicy.getHttpNetworkRequestTimeout());
request.setContentBytes(new byte[0]);
request.requestContext.resourcePhysicalAddress = "dbs/db/colls/col";

Expand Down Expand Up @@ -568,6 +570,7 @@ public void networkFailures(RxDocumentServiceRequest request,
UserAgentContainer userAgentContainer = new UserAgentContainer();
ConnectionPolicy connectionPolicy = ConnectionPolicy.getDefaultPolicy();
connectionPolicy.setHttpNetworkRequestTimeout(Duration.ofSeconds(100));
request.setResponseTimeout(connectionPolicy.getHttpNetworkRequestTimeout());
HttpTransportClient transportClient = getHttpTransportClientUnderTest(
connectionPolicy,
userAgentContainer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,22 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.time.Duration;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests that partition manager correctly resolves addresses for requests and does appropriate number of cache refreshes.
* Tests that partition manager correctly resolves addresses for requests and does appropriate number of cache
* refreshes.
*/
public class ReactorNettyHttpClientTest {

private static final Logger logger = LoggerFactory.getLogger(ReactorNettyHttpClientTest.class);
private Configs configs;
private HttpClient reactorNettyHttpClient;

@BeforeClass(groups = "unit")
public void before_ReactorNettyHttpClientTest() {
this.configs = new Configs();
this.reactorNettyHttpClient = HttpClient.createFixed(new HttpClientConfig(this.configs));
this.reactorNettyHttpClient = HttpClient.createFixed(new HttpClientConfig(new Configs()));
}

@AfterClass(groups = "unit")
Expand All @@ -40,21 +41,21 @@ public void after_ReactorNettyHttpClientTest() {
public void httpClientWithMaxHeaderSize() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
assertThat(httpClient.configuration().decoder().maxHeaderSize()).isEqualTo(this.configs.getMaxHttpHeaderSize());
assertThat(httpClient.configuration().decoder().maxHeaderSize()).isEqualTo(Configs.getMaxHttpHeaderSize());
}

@Test(groups = "unit")
public void httpClientWithMaxChunkSize() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
assertThat(httpClient.configuration().decoder().maxChunkSize()).isEqualTo(this.configs.getMaxHttpChunkSize());
assertThat(httpClient.configuration().decoder().maxChunkSize()).isEqualTo(Configs.getMaxHttpChunkSize());
}

@Test(groups = "unit")
public void httpClientWithMaxInitialLineLength() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
assertThat(httpClient.configuration().decoder().maxInitialLineLength()).isEqualTo(this.configs.getMaxHttpInitialLineLength());
assertThat(httpClient.configuration().decoder().maxInitialLineLength()).isEqualTo(Configs.getMaxHttpInitialLineLength());
}

@Test(groups = "unit")
Expand All @@ -65,10 +66,36 @@ public void httpClientWithValidateHeaders() {
}

@Test(groups = "unit")
public void httpClientWithOptions() {
public void httpClientWithConnectionAcquireTimeout() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
Integer connectionTimeoutInMillis =
(Integer) httpClient.configuration().options().get(ChannelOption.CONNECT_TIMEOUT_MILLIS);
assertThat(connectionTimeoutInMillis).isEqualTo((int) Configs.getConnectionAcquireTimeout().toMillis());
}

@Test(groups = "unit")
public void httpClientWithMaxPoolSize() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
int maxConnectionPoolSize = httpClient.configuration().connectionProvider().maxConnections();
assertThat(maxConnectionPoolSize).isEqualTo(Configs.getDefaultHttpPoolSize());
}

@Test(groups = "unit")
// We don't set any default response timeout to http client
public void httpClientWithResponseTimeout() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
Duration responseTimeout = httpClient.configuration().responseTimeout();
assertThat(responseTimeout).isNull();
}

@Test(groups = "unit")
public void httpClientWithConnectionProviderName() {
reactor.netty.http.client.HttpClient httpClient =
ReflectionUtils.get(reactor.netty.http.client.HttpClient.class, this.reactorNettyHttpClient, "httpClient");
Integer connectionTimeoutInMillis = (Integer) httpClient.configuration().options().get(ChannelOption.CONNECT_TIMEOUT_MILLIS);
assertThat(connectionTimeoutInMillis).isEqualTo((int) this.configs.getConnectionAcquireTimeout().toMillis());
String name = httpClient.configuration().connectionProvider().name();
assertThat(name).isEqualTo(Configs.getReactorNettyConnectionPoolName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.azure.cosmos.implementation.IRoutingMapProvider;
import com.azure.cosmos.implementation.MetadataDiagnosticsContext;
import com.azure.cosmos.implementation.PartitionKeyRange;
import com.azure.cosmos.implementation.Resource;
import com.azure.cosmos.implementation.Utils;
import com.azure.cosmos.implementation.apachecommons.lang.tuple.ImmutablePair;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -90,6 +91,18 @@ public Mono<Utils.ValueHolder<PartitionKeyRange>> tryGetPartitionKeyRangeByIdAsy
}
}

private class MockIRoutingMapProviderWithNullRoutingMap extends MockIRoutingMapProvider {

public MockIRoutingMapProviderWithNullRoutingMap(List<PartitionKeyRange> ranges) {
super(ranges);
}

@Override
public Mono<Utils.ValueHolder<List<PartitionKeyRange>>> tryGetOverlappingRangesAsync(MetadataDiagnosticsContext metaDataDiagnosticsContext, String collectionResourceId, Range<String> range, boolean forceRefresh, Map<String, Object> properties) {
return Mono.just(new Utils.ValueHolder<>(null));
}
}


@Test(groups = { "unit" }, expectedExceptions = IllegalArgumentException.class)
public void nonSortedRanges() {
Expand Down Expand Up @@ -213,4 +226,21 @@ public String apply(PartitionKeyRange range) {
assertThat(2).isEqualTo(overLappingRangeList.size());
assertThat("0,1").isEqualTo(overLappingRangeList.stream().map(func).collect(Collectors.joining(",")));
}

@Test(groups = {"unit"})
// This test is to verify that the NPE has been fixed in RoutingMapProviderHelper.getOverlappingRanges
public void getOverlappingRangesWithoutOverlapping() {

Function<PartitionKeyRange, String> func = Resource::getId;

List<PartitionKeyRange> rangeList = Arrays.asList(new PartitionKeyRange("0", "", "FF"));

IRoutingMapProvider routingMapProviderMock = new MockIRoutingMapProviderWithNullRoutingMap(rangeList);

Mono<List<PartitionKeyRange>> overlappingRanges;
overlappingRanges = RoutingMapProviderHelper.getOverlappingRanges(routingMapProviderMock,
"coll1",
Arrays.asList(new Range<String>("", "FF", true, false)));
assertThat("").isEqualTo(overlappingRanges.block().stream().map(func).collect(Collectors.joining(",")));
}
}
Loading
Loading