Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forwarding port changes in 2.4 to main branch (add test coverage to transportUploadModelAction) #593

Merged
merged 1 commit into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.model.MLModelManager',
'org.opensearch.ml.action.unload.TransportUnloadModelAction',
'org.opensearch.ml.action.forward.TransportForwardAction',
'org.opensearch.ml.action.upload.TransportUploadModelAction',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.rest.RestMLUploadModelAction',
'org.opensearch.ml.rest.RestMLUnloadModelAction',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.upload;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.when;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.common.transport.upload.MLUploadModelRequest;
import org.opensearch.ml.common.transport.upload.UploadModelResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class TransportUploadModelActionTests extends OpenSearchTestCase {

@Mock
private TransportService transportService;

@Mock
private ModelHelper modelHelper;

@Mock
private MLIndicesHandler mlIndicesHandler;

@Mock
private MLModelManager mlModelManager;

@Mock
private MLTaskManager mlTaskManager;

@Mock
private ClusterService clusterService;

@Mock
private ThreadPool threadPool;

@Mock
private Client client;

@Mock
private DiscoveryNodeHelper nodeFilter;

@Mock
private MLTaskDispatcher mlTaskDispatcher;

@Mock
private MLStats mlStats;

@Mock
private ActionFilters actionFilters;

@Mock
private Task task;

@Mock
private ActionListener<UploadModelResponse> actionListener;

@Mock
private DiscoveryNode node1;

@Mock
private DiscoveryNode node2;

@Mock
private IndexResponse indexResponse;

private TransportUploadModelAction transportUploadModelAction;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);

transportUploadModelAction = new TransportUploadModelAction(
transportService,
actionFilters,
modelHelper,
mlIndicesHandler,
mlModelManager,
mlTaskManager,
clusterService,
threadPool,
client,
nodeFilter,
mlTaskDispatcher,
mlStats
);
assertNotNull(transportUploadModelAction);

MLStat mlStat = mock(MLStat.class);
when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat);

doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(any(), any());

doAnswer(invocation -> {
ActionListener<DiscoveryNode> listener = invocation.getArgument(0);
listener.onResponse(node1);
return null;
}).when(mlTaskDispatcher).dispatch(any());

when(clusterService.localNode()).thenReturn(node2);

doAnswer(invocation -> { return null; }).when(mlModelManager).uploadMLModel(any(), any());

}

public void testDoExecute_successWithLocalNodeEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId1");

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class);
doAnswer(invocation -> {
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3);
handler.handleResponse(forwardResponse);
return null;
}).when(transportService).sendRequest(any(), any(), any(), any());

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testTransportUploadModelActionDoExecuteWithDispatchException() {
doAnswer(invocation -> {
ActionListener<Exception> listener = invocation.getArgument(0);
listener.onFailure(new Exception("Failed to dispatch upload model task "));
return null;
}).when(mlTaskDispatcher).dispatch(any());
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
}

public void testTransportUploadModelActionDoExecuteWithCreateTaskException() {
doAnswer(invocation -> {
ActionListener<Exception> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Failed to create upload model task"));
return null;
}).when(mlTaskManager).createMLTask(any(), any());

transportUploadModelAction.doExecute(task, prepareRequest(), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
}

private MLUploadModelRequest prepareRequest() {
MLUploadInput uploadInput = MLUploadInput
.builder()
.functionName(FunctionName.BATCH_RCF)
.loadModel(true)
.version("1.0")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config")
)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.url("Test URL")
.build();
return new MLUploadModelRequest(uploadInput);
}

}