Skip to content

Commit

Permalink
[Feature/multi_tenancy] Make model group access control checks tenant…
Browse files Browse the repository at this point in the history
… aware (#2867)

* Add more tenant-aware checks for model groups

Signed-off-by: Daniel Widdis <[email protected]>

* Make model group access control checks tenant aware

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Sep 5, 2024
1 parent d51ec4f commit 1096e47
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.sdk.DeleteDataObjectRequest;
import org.opensearch.sdk.DeleteDataObjectResponse;
Expand Down Expand Up @@ -216,18 +217,26 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
Executor executor,
Boolean isMultiTenancyEnabled
) {
log.info("Searching {}", Arrays.toString(request.indices()));
SearchSourceBuilder searchSource = request.searchSourceBuilder();
if (Boolean.TRUE.equals(isMultiTenancyEnabled)) {
if (request.tenantId() == null) {
return CompletableFuture.failedFuture(
new OpenSearchStatusException("Tenant ID is required when multitenancy is enabled.", RestStatus.BAD_REQUEST)
);
return CompletableFuture
.failedFuture(
new OpenSearchStatusException("Tenant ID is required when multitenancy is enabled.", RestStatus.BAD_REQUEST)
);
}
QueryBuilder existingQuery = searchSource.query();
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery().must(existingQuery == null ? new MatchAllQueryBuilder() : existingQuery);
boolQuery.filter(QueryBuilders.termQuery(CommonValue.TENANT_ID, request.tenantId()));
searchSource.query(boolQuery);
TermQueryBuilder tenantIdTermQuery = QueryBuilders.termQuery(CommonValue.TENANT_ID, request.tenantId());
if (existingQuery == null) {
searchSource.query(tenantIdTermQuery);
} else {
BoolQueryBuilder boolQuery = existingQuery instanceof BoolQueryBuilder
? (BoolQueryBuilder) existingQuery
: QueryBuilders.boolQuery().must(existingQuery);
boolQuery.filter(tenantIdTermQuery);
searchSource.query(boolQuery);
}
log.debug("Adding tenant id to search query", Arrays.toString(request.indices()));
}
log.info("Searching {}", Arrays.toString(request.indices()));
ActionFuture<SearchResponse> searchResponseFuture = AccessController
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,7 @@ public void testSearchDataObjectTenantAware() throws IOException {
verify(mockedClient, times(1)).search(requestCaptor.capture());
assertEquals(1, requestCaptor.getValue().indices().length);
assertEquals(TEST_INDEX, requestCaptor.getValue().indices()[0]);
assertTrue(requestCaptor.getValue().source().toString().contains("\"query\":{\"bool\":{\"must\":"));
assertTrue(requestCaptor.getValue().source().toString().contains("\"filter\":[{\"term\":{\"tenant_id\":{\"value\":\"xyz\""));
assertTrue(requestCaptor.getValue().source().toString().contains("{\"term\":{\"tenant_id\":{\"value\":\"xyz\""));

SearchResponse searchActionResponse = SearchResponse.fromXContent(response.parser());
assertEquals(0, searchActionResponse.getFailedShards());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,61 +96,71 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);

SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(ML_MODEL_INDEX)
.searchSourceBuilder(searchSourceBuilder)
.build();

sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((sr, st) -> {
if (sr != null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteModelGroup(deleteRequest, tenantId, wrappedListener);
modelAccessControlHelper
.validateModelGroupAccess(
user,
mlFeatureEnabledSetting,
tenantId,
modelGroupId,
client,
sdkClient,
ActionListener.wrap(access -> {
if (!access) {
wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);

SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(ML_MODEL_INDEX)
.tenantId(tenantId)
.searchSourceBuilder(searchSourceBuilder)
.build();

sdkClient
.searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((sr, st) -> {
if (sr != null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteModelGroup(deleteRequest, tenantId, wrappedListener);
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"Cannot delete the model group when it has associated model versions",
RestStatus.CONFLICT
)
);
}
} catch (Exception e) {
log.error("Failed to parse search response", e);
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to parse search response",
RestStatus.INTERNAL_SERVER_ERROR
)
);
}
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
"Cannot delete the model group when it has associated model versions",
RestStatus.CONFLICT
)
);
Exception cause = SdkClientUtils.unwrapAndConvertToException(st);
handleModelSearchFailure(modelGroupId, tenantId, cause, actionListener);
}
} catch (Exception e) {
log.error("Failed to parse search response", e);
actionListener
.onFailure(
new OpenSearchStatusException(
"Failed to parse search response",
RestStatus.INTERNAL_SERVER_ERROR
)
);
}
} else {
Exception cause = SdkClientUtils.unwrapAndConvertToException(st);
handleModelSearchFailure(modelGroupId, tenantId, cause, actionListener);
}
});

}
}, e -> {
log.error("Failed to validate Access for Model Group {}", modelGroupId, e);
wrappedListener.onFailure(e);
}));
});

}
}, e -> {
log.error("Failed to validate Access for Model Group {}", modelGroupId, e);
wrappedListener.onFailure(e);
})
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
modelAccessControlHelper
.validateModelGroupAccess(
user,
mlFeatureEnabledSetting,
tenantId,
mlModel.getModelGroupId(),
client,
sdkClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,48 +201,58 @@ private void checkUserAccess(
) {
User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, sdkClient, ActionListener.wrap(access -> {
if (access) {
doRegister(registerModelInput, listener);
return;
}
// if the user does not have access, we need to check three more conditions before throwing exception.
// if we are checking the access based on the name provided in the input, we let user know the name is already used by a
// model group they do not have access to.
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null) {
listener
.onFailure(
new IllegalArgumentException(
"Without a model group ID, the system will use the model name {"
+ registerModelInput.getModelName()
+ "} to create a new model group. However, this name is taken by another group with id {"
+ registerModelInput.getModelGroupId()
+ "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
)
);
} else {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ registerModelInput.getModelName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ registerModelInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
.validateModelGroupAccess(
user,
mlFeatureEnabledSetting,
registerModelInput.getTenantId(),
registerModelInput.getModelGroupId(),
client,
sdkClient,
ActionListener.wrap(access -> {
if (access) {
doRegister(registerModelInput, listener);
return;
}
return;
}
// if user does not have access to the model group ID provided in the input, we let user know they do not have access to the
// specified model group
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
}, listener::onFailure));
// if the user does not have access, we need to check three more conditions before throwing exception.
// if we are checking the access based on the name provided in the input, we let user know the name is already used by a
// model group they do not have access to.
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the
// cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null) {
listener
.onFailure(
new IllegalArgumentException(
"Without a model group ID, the system will use the model name {"
+ registerModelInput.getModelName()
+ "} to create a new model group. However, this name is taken by another group with id {"
+ registerModelInput.getModelGroupId()
+ "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
)
);
} else {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ registerModelInput.getModelName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ registerModelInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
}
return;
}
// if user does not have access to the model group ID provided in the input, we let user know they do not have access to
// the
// specified model group
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
}, listener::onFailure)
);
}

private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener) {
Expand Down
Loading

0 comments on commit 1096e47

Please sign in to comment.