Skip to content

Commit

Permalink
Make tenant awareness setting static
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Sep 18, 2024
1 parent 66d8e2b commit 8ec61a4
Show file tree
Hide file tree
Showing 15 changed files with 96 additions and 54 deletions.
12 changes: 4 additions & 8 deletions common/src/main/java/org/opensearch/sdk/SdkClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@

import static org.opensearch.sdk.SdkClientUtils.unwrapAndConvertToException;

public class SdkClient implements SettingsChangeListener {
public class SdkClient {

private final SdkClientDelegate delegate;
private volatile Boolean isMultiTenancyEnabled;
private final Boolean isMultiTenancyEnabled;

public SdkClient(SdkClientDelegate delegate) {
public SdkClient(SdkClientDelegate delegate, Boolean multiTenancy) {
this.delegate = delegate;
}

@Override
public void onMultiTenancyEnabledChanged(boolean isEnabled) {
this.isMultiTenancyEnabled = isEnabled;
this.isMultiTenancyEnabled = multiTenancy;
}

/**
Expand Down
3 changes: 1 addition & 2 deletions common/src/test/java/org/opensearch/sdk/SdkClientTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
return CompletableFuture.completedFuture(searchResponse);
}
});
sdkClient = new SdkClient(sdkClientImpl);
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = new SdkClient(sdkClientImpl, true);
testException = new OpenSearchStatusException("Test", RestStatus.BAD_REQUEST);
interruptedException = new InterruptedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
Expand Down Expand Up @@ -119,8 +120,7 @@ public class LocalClusterIndicesClientTests {
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry));
sdkClient.onMultiTenancyEnabledChanged(false);
sdkClient = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), true);

testDataObject = new TestDataObject("foo");
}
Expand Down Expand Up @@ -559,8 +559,8 @@ public void testSearchDataObjectNotTenantAware() throws IOException {
when(mockedClient.search(any(SearchRequest.class))).thenReturn(future);
when(future.actionGet()).thenReturn(searchResponse);

sdkClient.onMultiTenancyEnabledChanged(false);
SearchDataObjectResponse response = sdkClient
SdkClient sdkClientNoTenant = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), false);
SearchDataObjectResponse response = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture()
.join();
Expand Down Expand Up @@ -608,7 +608,6 @@ public void testSearchDataObjectTenantAware() throws IOException {
when(mockedClient.search(any(SearchRequest.class))).thenReturn(future);
when(future.actionGet()).thenReturn(searchResponse);

sdkClient.onMultiTenancyEnabledChanged(true);
SearchDataObjectResponse response = sdkClient
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture()
Expand Down Expand Up @@ -655,9 +654,7 @@ public void testSearchDataObject_Exception() throws IOException {

@Test
public void testSearchDataObject_NullTenantId() throws IOException {
// Tests exception if multitenancy enabled
sdkClient.onMultiTenancyEnabledChanged(true);

// Tests exception if multitenancy enabled
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
Expand All @@ -675,4 +672,26 @@ public void testSearchDataObject_NullTenantId() throws IOException {
assertEquals(OpenSearchStatusException.class, cause.getClass());
assertEquals("Tenant ID is required when multitenancy is enabled.", cause.getMessage());
}

public void testSearchDataObject_NullTenantNoMultitenancy() throws IOException {
// Tests no status exception if multitenancy not enabled
SdkClient sdkClientNoTenant = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), false);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
.indices(TEST_INDEX)
// null tenant Id
.searchSourceBuilder(searchSourceBuilder)
.build();

CompletableFuture<SearchDataObjectResponse> future = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture();

CompletionException ce = assertThrows(CompletionException.class, () -> future.join());
Throwable cause = ce.getCause();
assertEquals(UnsupportedOperationException.class, cause.getClass());
assertEquals("test", cause.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public void setup() {
when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor("opensearch_ml_general"));

settings = Settings.builder().build();
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), true);
mlAgentExecutor = Mockito
.spy(new MLAgentExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories, memoryMap, false));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void setUp() {
MockitoAnnotations.openMocks(this);
masterKey = new ConcurrentHashMap<>();
masterKey.put(DEFAULT_TENANT_ID, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), true);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
Expand Down
13 changes: 10 additions & 3 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ dependencies {
implementation("software.amazon.awssdk:utils:2.25.40")
// AWS OpenSearch Service dependency
implementation("software.amazon.awssdk:apache-client:2.25.40")

configurations.all {
resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4'
resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5-h2:5.2.4'
Expand All @@ -104,7 +104,8 @@ dependencies {

publishing {
publications {
pluginZip(MavenPublication) { publication ->
pluginZip(MavenPublication) {
publication ->
pom {
name = opensearchplugin.name
description = opensearchplugin.description
Expand Down Expand Up @@ -173,7 +174,9 @@ task integTest(type: RestIntegTestTask) {
testClassesDirs = sourceSets.test.output.classesDirs
classpath = sourceSets.test.runtimeClasspath
}
tasks.named("check").configure { dependsOn(integTest) }
tasks.named("check").configure {
dependsOn(integTest)
}

integTest {
dependsOn "bundlePlugin"
Expand Down Expand Up @@ -246,6 +249,10 @@ testClusters.integTest {
environment "AWS_SECRET_ACCESS_KEY", System.getenv("AWS_SECRET_ACCESS_KEY");
environment "AWS_SESSION_TOKEN", System.getenv("AWS_SESSION_TOKEN");

if (System.getProperty("tests.rest.tenantaware") != null) {
environment "plugins.ml_commons.multi_tenancy_enabled", "true"
}

testDistribution = "ARCHIVE"
// Cluster shrink exception thrown if we try to set numberOfNodes to 1, so only apply if > 1
if (_numNodes > 1) numberOfNodes = _numNodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,6 @@ public Collection<Object> createComponents(
memoryFactoryMap,
mlFeatureEnabledSetting.isMultiTenancyEnabled()
);
// Register the sdkClient as a listener
mlFeatureEnabledSetting.addListener(sdkClient);
// Register the agentExecutor as a listener
mlFeatureEnabledSetting.addListener(agentExecutor);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(

private String getIndexName(String index) {
// System index is not supported in remote index. Replacing '.' from index name.
return index.replaceAll("\\.", "");
return (index.length() > 1 && index.charAt(0) == '.') ? index.substring(1) : index;
}

private XContentParser createParser(String json) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
) {
return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction<SearchDataObjectResponse>) () -> {
try {
log.info("Searching {}", Arrays.toString(request.indices()), null);
log.info("Searching {}", Arrays.toString(request.indices()));
// work around https://github.com/opensearch-project/opensearch-java/issues/1150
String json = SdkClientUtils
.lowerCaseEnumValues(
Expand All @@ -254,6 +254,8 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
.filter(tenantIdFilterQuery.toQuery())
.build();
searchRequest = searchRequest.toBuilder().index(Arrays.asList(request.indices())).query(boolQuery.toQuery()).build();
} else {
searchRequest = searchRequest.toBuilder().index(Arrays.asList(request.indices())).build();
}
SearchResponse<?> searchResponse = openSearchClient.search(searchRequest, MAP_DOCTYPE);
log.info("Search returned {} hits", searchResponse.hits().total().value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.ml.sdkclient;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
import static org.opensearch.sdk.SdkClientSettings.AWS_DYNAMO_DB;
import static org.opensearch.sdk.SdkClientSettings.AWS_OPENSEARCH_SERVICE;
import static org.opensearch.sdk.SdkClientSettings.REMOTE_METADATA_ENDPOINT;
Expand Down Expand Up @@ -84,19 +85,21 @@ public static SdkClient createSdkClient(Client client, NamedXContentRegistry xCo
String remoteMetadataEndpoint = REMOTE_METADATA_ENDPOINT.get(settings);
String region = REMOTE_METADATA_REGION.get(settings);
String serviceName = REMOTE_METADATA_SERVICE_NAME.get(settings);
Boolean multiTenancy = ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings);

switch (remoteMetadataType) {
case REMOTE_OPENSEARCH:
if (Strings.isBlank(remoteMetadataEndpoint)) {
throw new OpenSearchException("Remote Opensearch client requires a metadata endpoint.");
}
log.info("Using remote opensearch cluster as metadata store");
return new SdkClient(new RemoteClusterIndicesClient(createOpenSearchClient(remoteMetadataEndpoint)));
return new SdkClient(new RemoteClusterIndicesClient(createOpenSearchClient(remoteMetadataEndpoint)), multiTenancy);
case AWS_OPENSEARCH_SERVICE:
validateAwsParams(remoteMetadataType, remoteMetadataEndpoint, region, serviceName);
log.info("Using remote AWS Opensearch Service cluster as metadata store");
return new SdkClient(
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName))
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName)),
multiTenancy
);
case AWS_DYNAMO_DB:
validateAwsParams(remoteMetadataType, remoteMetadataEndpoint, region, serviceName);
Expand All @@ -105,11 +108,12 @@ public static SdkClient createSdkClient(Client client, NamedXContentRegistry xCo
new DDBOpenSearchClient(
createDynamoDbClient(region),
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName))
)
),
multiTenancy
);
default:
log.info("Using local opensearch cluster as metadata store");
return new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
return new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), multiTenancy);
}
}

Expand All @@ -123,8 +127,8 @@ private static void validateAwsParams(String clientType, String remoteMetadataEn
}

// Package private for testing
static SdkClient wrapSdkClientDelegate(SdkClientDelegate delegate) {
return new SdkClient(delegate);
static SdkClient wrapSdkClientDelegate(SdkClientDelegate delegate, Boolean multiTenancy) {
return new SdkClient(delegate, multiTenancy);
}

private static DynamoDbClient createDynamoDbClient(String region) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_AGENT_FRAMEWORK_ENABLED = Setting
.boolSetting("plugins.ml_commons.agent_framework_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

// Whether multi-tenancy is enabled in ML Commons.
// This is a static setting which must be set before starting OpenSearch by (in priority order):
// 1. As a command-line argument using the -E flag (overrides other options):
// ./bin/opensearch -Eplugins.ml_commons.multi_tenancy_enabled=true
// 2. As a system property using OPENSEARCH_JAVA_OPTS (overrides opensearch.yml):
// export OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true"
// ./bin/opensearch
// Or inline when starting OpenSearch:
// OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true" ./bin/opensearch
// 3. In the opensearch.yml configuration file:
// plugins.ml_commons.multi_tenancy_enabled: true
// After setting it, a full cluster restart is required for the changes to take effect.
public static final Setting<Boolean> ML_COMMONS_MULTI_TENANCY_ENABLED = Setting
.boolSetting("plugins.ml_commons.multi_tenancy_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
.boolSetting("plugins.ml_commons.multi_tenancy_enabled", false, Setting.Property.NodeScope);
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MULTI_TENANCY_ENABLED, it -> {
isMultiTenancyEnabled = it;
notifyMultiTenancyListeners(it);
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,6 @@ public void setupSettings() throws IOException {
response = TestHelper
.makeRequest(client(), "PUT", "_cluster/settings", ImmutableMap.of(), TestHelper.toHttpEntity(jsonEntity), null);
assertEquals(200, response.getStatusLine().getStatusCode());

String multiTenancyEntity = "{\n"
+ " \"persistent\" : {\n"
+ " \"plugins.ml_commons.multi_tenancy_enabled\" : false \n"
+ " }\n"
+ "}";

response = TestHelper
.makeRequest(client(), "PUT", "_cluster/settings", ImmutableMap.of(), TestHelper.toHttpEntity(multiTenancyEntity), null);
assertEquals(200, response.getStatusLine().getStatusCode());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ public static void cleanup() {
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = SdkClientFactory.wrapSdkClientDelegate(new DDBOpenSearchClient(dynamoDbClient, remoteClusterIndicesClient));
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new DDBOpenSearchClient(dynamoDbClient, remoteClusterIndicesClient), true);
testDataObject = new TestDataObject("foo");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ public void setup() {
.setSerializationInclusion(JsonInclude.Include.NON_NULL)
)
);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient));
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient), true);
testDataObject = new TestDataObject("foo");
}

Expand Down Expand Up @@ -592,8 +591,6 @@ public void testSearchDataObject_Exception() throws IOException {

public void testSearchDataObject_NullTenant() throws IOException {
// Tests exception if multitenancy enabled
sdkClient.onMultiTenancyEnabledChanged(true);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
Expand All @@ -612,4 +609,27 @@ public void testSearchDataObject_NullTenant() throws IOException {
assertEquals(OpenSearchStatusException.class, cause.getClass());
assertEquals("Tenant ID is required when multitenancy is enabled.", cause.getMessage());
}

public void testSearchDataObject_NullTenantNoMultitenancy() throws IOException {
// Tests no status exception if multitenancy not enabled
SdkClient sdkClientNoTenant = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient), false);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
.indices(TEST_INDEX)
// null tenant Id
.searchSourceBuilder(searchSourceBuilder)
.build();

when(mockedOpenSearchClient.search(any(SearchRequest.class), any())).thenThrow(new UnsupportedOperationException("test"));
CompletableFuture<SearchDataObjectResponse> future = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture();

CompletionException ce = assertThrows(CompletionException.class, () -> future.join());
Throwable cause = ce.getCause();
assertEquals(UnsupportedOperationException.class, cause.getClass());
assertEquals("test", cause.getMessage());
}
}

0 comments on commit 8ec61a4

Please sign in to comment.