Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forwarding port changes in 2.4 to main branch ([Backport] update delete model TransportAction to support custom model) #566

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
package org.opensearch.ml.action.models;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
Expand All @@ -20,15 +23,26 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class DeleteModelTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> {

static final String TIMEOUT_MSG = "Timeout while deleting model of ";
static final String BULK_FAILURE_MSG = "Bulk failure while deleting model of ";
static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks ";
Client client;

@Inject
Expand All @@ -48,13 +62,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Completed Delete Model Request, model id:{} deleted", modelId);
actionListener.onResponse(deleteResponse);
deleteModelChunks(modelId, deleteResponse, actionListener);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML model " + modelId, e);
log.error("Failed to delete ML model meta Data" + modelId, e);
if (e instanceof ResourceNotFoundException) {
deleteModelChunks(modelId, null, actionListener);
}
actionListener.onFailure(e);
}
});
Expand All @@ -64,4 +80,40 @@ public void onFailure(Exception e) {
}
}

@VisibleForTesting
void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener<DeleteResponse> actionListener) {
DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId));

client.execute(DeleteByQueryAction.INSTANCE, deleteModelsRequest, ActionListener.wrap(r -> {
if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0)
&& (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) {
log.info("All model chunks are deleted for model {}", modelId);
if (deleteResponse != null) {
// If model metaData not found and deleteResponse is null, do not return here.
// ResourceNotFound is returned to notify that this model was deleted.
// This is a walk around to avoid cleaning up model leftovers. Will revisit if necessary.
actionListener.onResponse(deleteResponse);
}
} else {
returnFailure(r, modelId, actionListener);
}
}, e -> {
log.info("Failed to delete ML model for " + modelId, e);
actionListener.onFailure(e);
}));
}

private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener<DeleteResponse> actionListener) {
String errorMessage = "";
if (response.isTimedOut()) {
errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + TIMEOUT_MSG + modelId;
} else if (!response.getBulkFailures().isEmpty()) {
errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + BULK_FAILURE_MSG + modelId;
} else {
errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + SEARCH_FAILURE_MSG + modelId;
}
log.debug(response.toString());
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.action.models.DeleteModelTransportAction.BULK_FAILURE_MSG;
import static org.opensearch.ml.action.models.DeleteModelTransportAction.OS_STATUS_EXCEPTION_MESSAGE;
import static org.opensearch.ml.action.models.DeleteModelTransportAction.SEARCH_FAILURE_MSG;
import static org.opensearch.ml.action.models.DeleteModelTransportAction.TIMEOUT_MSG;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -20,11 +27,14 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -49,6 +59,9 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
@Mock
DeleteResponse deleteResponse;

@Mock
BulkByScrollResponse bulkByScrollResponse;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -76,10 +89,29 @@ public void testDeleteModel_Success() {
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModelChunks_Success() {
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_RuntimeException() {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
Expand All @@ -100,4 +132,70 @@ public void testDeleteModel_ThreadContextError() {
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("thread context error", argumentCaptor.getValue().getMessage());
}

public void test_FailToDeleteModel() {
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException("errorMessage"));
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

public void test_FailToDeleteAllModelChunks() {
BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!"));
when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure));
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public void test_FailToDeleteAllModelChunks_TimeOut() {
BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!"));
when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure));
when(bulkByScrollResponse.isTimedOut()).thenReturn(true);
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public void test_FailToDeleteAllModelChunks_SearchFailure() {
ScrollableHitSource.SearchFailure searchFailure = new ScrollableHitSource.SearchFailure(
new RuntimeException("error"),
ML_MODEL_INDEX,
123,
"node_id"
);
when(bulkByScrollResponse.getBulkFailures()).thenReturn(new ArrayList<>());
when(bulkByScrollResponse.isTimedOut()).thenReturn(false);
when(bulkByScrollResponse.getSearchFailures()).thenReturn(Arrays.asList(searchFailure));
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());

deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}
}