Skip to content

Commit

Permalink
add test coverage to transportUploadModelAction (#511) (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4sjoo authored Nov 30, 2022
1 parent 527cf66 commit 99306dc
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 1 deletion.
1 change: 0 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.model.MLModelManager',
'org.opensearch.ml.action.unload.TransportUnloadModelAction',
'org.opensearch.ml.action.forward.TransportForwardAction',
'org.opensearch.ml.action.upload.TransportUploadModelAction',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.rest.RestMLUploadModelAction',
'org.opensearch.ml.rest.RestMLUnloadModelAction',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.upload;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.when;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.common.transport.upload.MLUploadModelRequest;
import org.opensearch.ml.common.transport.upload.UploadModelResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class TransportUploadModelActionTests extends OpenSearchTestCase {

@Mock
private TransportService transportService;

@Mock
private ModelHelper modelHelper;

@Mock
private MLIndicesHandler mlIndicesHandler;

@Mock
private MLModelManager mlModelManager;

@Mock
private MLTaskManager mlTaskManager;

@Mock
private ClusterService clusterService;

@Mock
private ThreadPool threadPool;

@Mock
private Client client;

@Mock
private DiscoveryNodeHelper nodeFilter;

@Mock
private MLTaskDispatcher mlTaskDispatcher;

@Mock
private MLStats mlStats;

@Mock
private ActionFilters actionFilters;

@Mock
private Task task;

@Mock
private ActionListener<UploadModelResponse> actionListener;

@Mock
private DiscoveryNode node1;

@Mock
private DiscoveryNode node2;

@Mock
private IndexResponse indexResponse;

private TransportUploadModelAction transportUploadModelAction;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);

transportUploadModelAction = new TransportUploadModelAction(
transportService,
actionFilters,
modelHelper,
mlIndicesHandler,
mlModelManager,
mlTaskManager,
clusterService,
threadPool,
client,
nodeFilter,
mlTaskDispatcher,
mlStats
);
assertNotNull(transportUploadModelAction);

MLStat mlStat = mock(MLStat.class);
when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat);

doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(any(), any());

doAnswer(invocation -> {
ActionListener<DiscoveryNode> listener = invocation.getArgument(0);
listener.onResponse(node1);
return null;
}).when(mlTaskDispatcher).dispatch(any());

when(clusterService.localNode()).thenReturn(node2);

doAnswer(invocation -> { return null; }).when(mlModelManager).uploadMLModel(any(), any());

}

public void testDoExecute_successWithLocalNodeEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId1");

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class);
doAnswer(invocation -> {
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3);
handler.handleResponse(forwardResponse);
return null;
}).when(transportService).sendRequest(any(), any(), any(), any());

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testTransportUploadModelActionDoExecuteWithDispatchException() {
doAnswer(invocation -> {
ActionListener<Exception> listener = invocation.getArgument(0);
listener.onFailure(new Exception("Failed to dispatch upload model task "));
return null;
}).when(mlTaskDispatcher).dispatch(any());
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
}

public void testTransportUploadModelActionDoExecuteWithCreateTaskException() {
doAnswer(invocation -> {
ActionListener<Exception> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Failed to create upload model task"));
return null;
}).when(mlTaskManager).createMLTask(any(), any());

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
}

private MLUploadModelRequest prepareRequest() {
MLUploadInput uploadInput = MLUploadInput
.builder()
.functionName(FunctionName.BATCH_RCF)
.loadModel(true)
.version("1.0")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config")
)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.url("Test URL")
.build();
return new MLUploadModelRequest(uploadInput);
}

}

0 comments on commit 99306dc

Please sign in to comment.