Skip to content

Commit

Permalink
Fix the flaky test due to m_l_limit_exceeded_exception (#150)
Browse files Browse the repository at this point in the history
* increase the CB threshold, delete model after test

Signed-off-by: zhichao-aws <[email protected]>

* add log

Signed-off-by: zhichao-aws <[email protected]>

* add wait time

Signed-off-by: zhichao-aws <[email protected]>

* enhancement: wait model undeploy before delete; refactor the wait response logic

Signed-off-by: zhichao-aws <[email protected]>

* modify ci yml

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored Jan 31, 2024
1 parent cb64cca commit 0791c34
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
14 changes: 4 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ jobs:
needs: Get-CI-Image-Tag
strategy:
matrix:
java:
- 11
- 17
- 21.0.1
java: [11, 17, 21]
name: Build and Test skills plugin on Linux
runs-on: ubuntu-latest
container:
Expand All @@ -65,14 +62,14 @@ jobs:
./gradlew publishToMavenLocal"
- name: Upload Coverage Report
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}

build-MacOS:
strategy:
matrix:
java: [ 11, 17 ]
java: [11, 17, 21]

name: Build and Test skills Plugin on MacOS
needs: Get-CI-Image-Tag
Expand All @@ -97,10 +94,7 @@ jobs:
build-windows:
strategy:
matrix:
java:
- 11
- 17
- 21.0.1
java: [11, 17, 21]
name: Build and Test skills plugin on Windows
needs: Get-CI-Image-Tag
runs-on: windows-latest
Expand Down
52 changes: 42 additions & 10 deletions src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
Expand All @@ -35,6 +36,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -57,6 +59,7 @@ public void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// default threshold for native circuit breaker is 90, it may be not enough on test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
}

Expand Down Expand Up @@ -123,26 +126,35 @@ protected String indexMonitor(String monitorAsJsonString) {
}

@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
protected Map<String, Object> waitResponseMeetingCondition(
String method,
String endpoint,
String jsonEntity,
Predicate<Map<String, Object>> condition
) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, (String) null, null);
Response response = makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
Map<String, Object> responseInMap = parseResponseToMap(response);
String state = responseInMap.get(MLTask.STATE_FIELD).toString();
if (state.equals(MLTaskState.COMPLETED.toString())) {
if (condition.test(responseInMap)) {
return responseInMap;
}
if (state.equals(MLTaskState.FAILED.toString())
|| state.equals(MLTaskState.CANCELLED.toString())
|| state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
fail("The task failed with state " + state);
}
logger.info("The " + i + "-th response: " + responseInMap.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The task failed to complete after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
return null;
}

@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
Predicate<Map<String, Object>> condition = responseInMap -> {
String state = responseInMap.get(MLTask.STATE_FIELD).toString();
return state.equals(MLTaskState.COMPLETED.toString());
};
return waitResponseMeetingCondition("GET", "/_plugins/_ml/tasks/" + taskId, (String) null, condition);
}

// Register the model then deploy it. Returns the model_id until the model is deployed
protected String registerModelThenDeploy(String requestBody) {
String registerModelTaskId = registerModel(requestBody);
Expand All @@ -153,6 +165,26 @@ protected String registerModelThenDeploy(String requestBody) {
return modelId;
}

@SneakyThrows
private void waitModelUndeployed(String modelId) {
Predicate<Map<String, Object>> condition = responseInMap -> {
String state = responseInMap.get(MLModel.MODEL_STATE_FIELD).toString();
return !state.equals(MLModelState.DEPLOYED.toString())
&& !state.equals(MLModelState.DEPLOYING.toString())
&& !state.equals(MLModelState.PARTIALLY_DEPLOYED.toString());
};
waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, (String) null, condition);
return;
}

@SneakyThrows
protected void deleteModel(String modelId) {
// need to undeploy first as model can be in use
makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);
waitModelUndeployed(modelId);
makeRequest(client(), "DELETE", "/_plugins/_ml/models/" + modelId, null, (String) null, null);
}

protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception {
Response response = makeRequest(client(), "PUT", indexName, null, indexConfiguration, null);
Map<String, Object> responseInMap = parseResponseToMap(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public void setUp() {
public void tearDown() {
super.tearDown();
deleteExternalIndices();
deleteModel(modelId);
}

public void testNeuralSparseSearchToolInFlowAgent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public void stopMockLLM() {
server.stop(1);
}

@After
public void deleteModel() {
deleteModel(modelId);
}

private String setUpConnector() {
String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort());
return createConnector(
Expand Down

0 comments on commit 0791c34

Please sign in to comment.