From 99306dc0da2b8159774125b70772b5a1a35f28f0 Mon Sep 17 00:00:00 2001 From: Sicheng Song <114637679+b4sjoo@users.noreply.github.com> Date: Wed, 30 Nov 2022 15:22:55 -0800 Subject: [PATCH] add test coverage to transportUploadModelAction (#511) (#593) --- plugin/build.gradle | 1 - .../TransportUploadModelActionTests.java | 205 ++++++++++++++++++ 2 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/action/upload/TransportUploadModelActionTests.java diff --git a/plugin/build.gradle b/plugin/build.gradle index 32dbd277f5..dafc63463f 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -242,7 +242,6 @@ List jacocoExclusions = [ 'org.opensearch.ml.model.MLModelManager', 'org.opensearch.ml.action.unload.TransportUnloadModelAction', 'org.opensearch.ml.action.forward.TransportForwardAction', - 'org.opensearch.ml.action.upload.TransportUploadModelAction', 'org.opensearch.ml.rest.RestMLPredictionAction', 'org.opensearch.ml.rest.RestMLUploadModelAction', 'org.opensearch.ml.rest.RestMLUnloadModelAction', diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload/TransportUploadModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload/TransportUploadModelActionTests.java new file mode 100644 index 0000000000..ba20a52510 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/upload/TransportUploadModelActionTests.java @@ -0,0 +1,205 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.upload; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.forward.MLForwardResponse; +import org.opensearch.ml.common.transport.upload.MLUploadInput; +import org.opensearch.ml.common.transport.upload.MLUploadModelRequest; +import org.opensearch.ml.common.transport.upload.UploadModelResponse; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLNodeLevelStat; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.task.MLTaskDispatcher; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportUploadModelActionTests extends OpenSearchTestCase { + + @Mock + private TransportService transportService; + + @Mock + private ModelHelper modelHelper; + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private MLTaskManager mlTaskManager; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private Client client; + + @Mock + private DiscoveryNodeHelper nodeFilter; + + @Mock + private MLTaskDispatcher mlTaskDispatcher; + + @Mock + private MLStats mlStats; + + @Mock + private ActionFilters actionFilters; + + @Mock + private Task task; + + @Mock + private ActionListener actionListener; + + @Mock + private DiscoveryNode node1; + + @Mock + private DiscoveryNode node2; + + @Mock + private IndexResponse indexResponse; + + private TransportUploadModelAction transportUploadModelAction; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + transportUploadModelAction = new TransportUploadModelAction( + transportService, + actionFilters, + modelHelper, + mlIndicesHandler, + mlModelManager, + mlTaskManager, + clusterService, + threadPool, + client, + nodeFilter, + mlTaskDispatcher, + mlStats + ); + assertNotNull(transportUploadModelAction); + + MLStat mlStat = mock(MLStat.class); + when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(node1); + return null; + }).when(mlTaskDispatcher).dispatch(any()); + + when(clusterService.localNode()).thenReturn(node2); + + doAnswer(invocation -> { return null; }).when(mlModelManager).uploadMLModel(any(), any()); + + } + + public void testDoExecute_successWithLocalNodeEqualToClusterNode() { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId1"); + + transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + + transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testTransportUploadModelActionDoExecuteWithDispatchException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new Exception("Failed to dispatch upload model task ")); + return null; + }).when(mlTaskDispatcher).dispatch(any()); + transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + public void testTransportUploadModelActionDoExecuteWithCreateTaskException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("Failed to create upload model task")); + return null; + }).when(mlTaskManager).createMLTask(any(), any()); + + transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + private MLUploadModelRequest prepareRequest() { + MLUploadInput uploadInput = MLUploadInput + .builder() + .functionName(FunctionName.BATCH_RCF) + .loadModel(true) + .version("1.0") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config") + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("Test URL") + .build(); + return new MLUploadModelRequest(uploadInput); + } + +}