-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
2 changed files
with
205 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
205 changes: 205 additions & 0 deletions
205
plugin/src/test/java/org/opensearch/ml/action/upload/TransportUploadModelActionTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.upload; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.ArgumentMatchers.eq; | ||
import static org.mockito.Mockito.*; | ||
import static org.mockito.Mockito.when; | ||
|
||
import org.junit.Before; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mock; | ||
import org.mockito.Mockito; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.ActionListenerResponseHandler; | ||
import org.opensearch.action.index.IndexResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.node.DiscoveryNode; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.ml.cluster.DiscoveryNodeHelper; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
import org.opensearch.ml.common.transport.forward.MLForwardResponse; | ||
import org.opensearch.ml.common.transport.upload.MLUploadInput; | ||
import org.opensearch.ml.common.transport.upload.MLUploadModelRequest; | ||
import org.opensearch.ml.common.transport.upload.UploadModelResponse; | ||
import org.opensearch.ml.engine.ModelHelper; | ||
import org.opensearch.ml.indices.MLIndicesHandler; | ||
import org.opensearch.ml.model.MLModelManager; | ||
import org.opensearch.ml.stats.MLNodeLevelStat; | ||
import org.opensearch.ml.stats.MLStat; | ||
import org.opensearch.ml.stats.MLStats; | ||
import org.opensearch.ml.task.MLTaskDispatcher; | ||
import org.opensearch.ml.task.MLTaskManager; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
import org.opensearch.threadpool.ThreadPool; | ||
import org.opensearch.transport.TransportService; | ||
|
||
public class TransportUploadModelActionTests extends OpenSearchTestCase { | ||
|
||
@Mock | ||
private TransportService transportService; | ||
|
||
@Mock | ||
private ModelHelper modelHelper; | ||
|
||
@Mock | ||
private MLIndicesHandler mlIndicesHandler; | ||
|
||
@Mock | ||
private MLModelManager mlModelManager; | ||
|
||
@Mock | ||
private MLTaskManager mlTaskManager; | ||
|
||
@Mock | ||
private ClusterService clusterService; | ||
|
||
@Mock | ||
private ThreadPool threadPool; | ||
|
||
@Mock | ||
private Client client; | ||
|
||
@Mock | ||
private DiscoveryNodeHelper nodeFilter; | ||
|
||
@Mock | ||
private MLTaskDispatcher mlTaskDispatcher; | ||
|
||
@Mock | ||
private MLStats mlStats; | ||
|
||
@Mock | ||
private ActionFilters actionFilters; | ||
|
||
@Mock | ||
private Task task; | ||
|
||
@Mock | ||
private ActionListener<UploadModelResponse> actionListener; | ||
|
||
@Mock | ||
private DiscoveryNode node1; | ||
|
||
@Mock | ||
private DiscoveryNode node2; | ||
|
||
@Mock | ||
private IndexResponse indexResponse; | ||
|
||
private TransportUploadModelAction transportUploadModelAction; | ||
|
||
@Before | ||
public void setup() { | ||
MockitoAnnotations.openMocks(this); | ||
|
||
transportUploadModelAction = new TransportUploadModelAction( | ||
transportService, | ||
actionFilters, | ||
modelHelper, | ||
mlIndicesHandler, | ||
mlModelManager, | ||
mlTaskManager, | ||
clusterService, | ||
threadPool, | ||
client, | ||
nodeFilter, | ||
mlTaskDispatcher, | ||
mlStats | ||
); | ||
assertNotNull(transportUploadModelAction); | ||
|
||
MLStat mlStat = mock(MLStat.class); | ||
when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<IndexResponse> listener = invocation.getArgument(1); | ||
listener.onResponse(indexResponse); | ||
return null; | ||
}).when(mlTaskManager).createMLTask(any(), any()); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<DiscoveryNode> listener = invocation.getArgument(0); | ||
listener.onResponse(node1); | ||
return null; | ||
}).when(mlTaskDispatcher).dispatch(any()); | ||
|
||
when(clusterService.localNode()).thenReturn(node2); | ||
|
||
doAnswer(invocation -> { return null; }).when(mlModelManager).uploadMLModel(any(), any()); | ||
|
||
} | ||
|
||
public void testDoExecute_successWithLocalNodeEqualToClusterNode() { | ||
when(node1.getId()).thenReturn("NodeId1"); | ||
when(node2.getId()).thenReturn("NodeId1"); | ||
|
||
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); | ||
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class); | ||
verify(actionListener).onResponse(argumentCaptor.capture()); | ||
} | ||
|
||
public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { | ||
when(node1.getId()).thenReturn("NodeId1"); | ||
when(node2.getId()).thenReturn("NodeId2"); | ||
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); | ||
doAnswer(invocation -> { | ||
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3); | ||
handler.handleResponse(forwardResponse); | ||
return null; | ||
}).when(transportService).sendRequest(any(), any(), any(), any()); | ||
|
||
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); | ||
ArgumentCaptor<UploadModelResponse> argumentCaptor = ArgumentCaptor.forClass(UploadModelResponse.class); | ||
verify(actionListener).onResponse(argumentCaptor.capture()); | ||
} | ||
|
||
public void testTransportUploadModelActionDoExecuteWithDispatchException() { | ||
doAnswer(invocation -> { | ||
ActionListener<Exception> listener = invocation.getArgument(0); | ||
listener.onFailure(new Exception("Failed to dispatch upload model task ")); | ||
return null; | ||
}).when(mlTaskDispatcher).dispatch(any()); | ||
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); | ||
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); | ||
verify(actionListener).onFailure(argumentCaptor.capture()); | ||
} | ||
|
||
public void testTransportUploadModelActionDoExecuteWithCreateTaskException() { | ||
doAnswer(invocation -> { | ||
ActionListener<Exception> listener = invocation.getArgument(1); | ||
listener.onFailure(new Exception("Failed to create upload model task")); | ||
return null; | ||
}).when(mlTaskManager).createMLTask(any(), any()); | ||
|
||
transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); | ||
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); | ||
verify(actionListener).onFailure(argumentCaptor.capture()); | ||
} | ||
|
||
private MLUploadModelRequest prepareRequest() { | ||
MLUploadInput uploadInput = MLUploadInput | ||
.builder() | ||
.functionName(FunctionName.BATCH_RCF) | ||
.loadModel(true) | ||
.version("1.0") | ||
.modelName("Test Model") | ||
.modelConfig( | ||
new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config") | ||
) | ||
.modelFormat(MLModelFormat.TORCH_SCRIPT) | ||
.url("Test URL") | ||
.build(); | ||
return new MLUploadModelRequest(uploadInput); | ||
} | ||
|
||
} |