diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index f8c1a193b6..d035fd4885 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,11 +6,14 @@ package org.opensearch.ml.action.models; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; @@ -20,15 +23,26 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + @Log4j2 @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class DeleteModelTransportAction extends HandledTransportAction { + static final String TIMEOUT_MSG = "Timeout while deleting model of "; + static final String BULK_FAILURE_MSG = "Bulk failure while deleting model of "; + static final String SEARCH_FAILURE_MSG = "Search 
failure while deleting model of "; + static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks "; Client client; @Inject @@ -48,13 +62,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener() { @Override public void onResponse(DeleteResponse deleteResponse) { - log.info("Completed Delete Model Request, model id:{} deleted", modelId); - actionListener.onResponse(deleteResponse); + deleteModelChunks(modelId, deleteResponse, actionListener); } @Override public void onFailure(Exception e) { - log.error("Failed to delete ML model " + modelId, e); + log.error("Failed to delete ML model meta data " + modelId, e); + if (e instanceof ResourceNotFoundException) { + // Leftover chunks are cleaned up with a log-only listener, so actionListener is notified + // exactly once (by the onFailure call below) even when the cleanup itself fails. + deleteModelChunks(modelId, null, ActionListener.wrap(ignored -> {}, cleanupError -> log.error("Failed to clean up chunks of ML model {}", modelId, cleanupError))); + } actionListener.onFailure(e); } }); @@ -64,4 +82,52 @@ public void onFailure(Exception e) { } } + /** + * Runs a delete-by-query on ML_MODEL_INDEX for documents whose MODEL_ID_FIELD matches modelId. + * On a clean result a non-null deleteResponse is passed to actionListener.onResponse; a timeout, + * bulk failure or search failure is reported through actionListener.onFailure instead. + */ + @VisibleForTesting + void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener actionListener) { + DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); + deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteModelsRequest, ActionListener.wrap(r -> { + boolean noBulkFailures = r.getBulkFailures() == null || r.getBulkFailures().isEmpty(); + boolean noSearchFailures = r.getSearchFailures() == null || r.getSearchFailures().isEmpty(); + // A timed-out request goes to returnFailure (which has a timeout message), never to the success path. + if (!r.isTimedOut() && noBulkFailures && noSearchFailures) { + log.info("All model chunks are deleted for model {}", modelId); + if (deleteResponse != null) { + // If model metaData not found and deleteResponse is null, do not return here. + // ResourceNotFound is returned to notify that this model was deleted. + // This is a workaround to avoid cleaning up model leftovers. Will revisit if necessary. 
+ actionListener.onResponse(deleteResponse); + } + } else { + returnFailure(r, modelId, actionListener); + } + }, e -> { + log.error("Failed to delete ML model for " + modelId, e); + actionListener.onFailure(e); + })); + } + + /** + * Fails actionListener with an INTERNAL_SERVER_ERROR status exception whose message names the + * first cause found: timeout, then bulk failure, otherwise search failure. + */ + private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { + // getBulkFailures() can be null (deleteModelChunks allows for it), so check before isEmpty(). + String errorMessage; + if (response.isTimedOut()) { + errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + TIMEOUT_MSG + modelId; + } else if (response.getBulkFailures() != null && !response.getBulkFailures().isEmpty()) { + errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + BULK_FAILURE_MSG + modelId; + } else { + errorMessage = OS_STATUS_EXCEPTION_MESSAGE + "," + SEARCH_FAILURE_MSG + modelId; + } + log.debug(response.toString()); + actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index a8417341b9..3550b587d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -10,8 +10,15 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.action.models.DeleteModelTransportAction.BULK_FAILURE_MSG; +import static org.opensearch.ml.action.models.DeleteModelTransportAction.OS_STATUS_EXCEPTION_MESSAGE; +import static org.opensearch.ml.action.models.DeleteModelTransportAction.SEARCH_FAILURE_MSG; +import static org.opensearch.ml.action.models.DeleteModelTransportAction.TIMEOUT_MSG; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import org.junit.Before; import org.junit.Rule; @@ -20,11 
+27,14 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -49,6 +59,9 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase { @Mock DeleteResponse deleteResponse; + /** Mocked delete-by-query response; individual tests stub its failure lists and timeout flag. */ @Mock + BulkByScrollResponse bulkByScrollResponse; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -76,10 +89,29 @@ public void testDeleteModel_Success() { return null; }).when(client).delete(any(), any()); + /* Answer the three-argument client.execute call with an empty merged BulkByScrollResponse. */ doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); } + /** Null bulk failures on the delete-by-query result: the listener receives the original delete response. */ public void testDeleteModelChunks_Success() { + when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + + deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + public void testDeleteModel_RuntimeException() { doAnswer(invocation -> { 
ActionListener listener = invocation.getArgument(1); @@ -100,4 +132,70 @@ public void testDeleteModel_ThreadContextError() { verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("thread context error", argumentCaptor.getValue().getMessage()); } + + /** An exception raised by the delete-by-query call reaches the listener's onFailure with its message intact. */ public void test_FailToDeleteModel() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).execute(any(), any(), any()); + + deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + /** A result carrying a bulk failure is reported to the listener with the bulk-failure message. */ public void test_FailToDeleteAllModelChunks() { + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!")); + when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + + deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); + } + + /** A timed-out result is reported with the timeout message even though a bulk failure is also present. */ public void test_FailToDeleteAllModelChunks_TimeOut() { + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!")); + when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure)); + when(bulkByScrollResponse.isTimedOut()).thenReturn(true); + doAnswer(invocation -> { + 
ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + + deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage()); + } + + /** With an empty bulk-failure list and no timeout, a search failure is reported with the search-failure message. */ public void test_FailToDeleteAllModelChunks_SearchFailure() { + ScrollableHitSource.SearchFailure searchFailure = new ScrollableHitSource.SearchFailure( + new RuntimeException("error"), + ML_MODEL_INDEX, + 123, + "node_id" + ); + when(bulkByScrollResponse.getBulkFailures()).thenReturn(new ArrayList<>()); + when(bulkByScrollResponse.isTimedOut()).thenReturn(false); + when(bulkByScrollResponse.getSearchFailures()).thenReturn(Arrays.asList(searchFailure)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + + deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OS_STATUS_EXCEPTION_MESSAGE + "," + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); + } }