Skip to content

Commit

Permalink
Delete model after each test in integ tests (#197)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Jun 7, 2023
1 parent 637bbe1 commit 12116a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5;

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

Expand Down Expand Up @@ -93,7 +94,7 @@ protected String uploadModel(String requestBody) throws Exception {
"/_plugins/_ml/models/_upload",
null,
toHttpEntity(requestBody),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
Expand Down Expand Up @@ -122,7 +123,7 @@ protected void loadModel(String modelId) throws Exception {
String.format(LOCALE, "/_plugins/_ml/models/%s/_load", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
Expand Down Expand Up @@ -170,7 +171,7 @@ protected float[] runInference(String modelId, String queryText) {
String.format(LOCALE, "/_plugins/_ml/_predict/text_embedding/%s", modelId),
null,
toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"],\"target_response\": [\"sentence_embedding\"]}", queryText)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
Expand Down Expand Up @@ -201,7 +202,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
indexName,
null,
toHttpEntity(indexConfiguration),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
Expand All @@ -225,7 +226,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro
modelId
)
),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
Expand Down Expand Up @@ -403,7 +404,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Excepti
String.format(LOCALE, "_plugins/_ml/tasks/%s", taskId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
return XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
Expand Down Expand Up @@ -491,4 +492,26 @@ protected static class KNNFieldConfig {
private final Integer dimension;
private final SpaceType spaceType;
}

@SneakyThrows
protected void deleteModel(String modelId) {
// need to undeploy first as model can be in use
makeRequest(
client(),
"POST",
String.format(LOCALE, "/_plugins/_ml/models/%s/_undeploy", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

makeRequest(
client(),
"DELETE",
String.format(LOCALE, "/_plugins/_ml/models/%s", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}
}
16 changes: 14 additions & 2 deletions src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
import static org.opensearch.neuralsearch.TestUtils.objectToFloat;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import lombok.SneakyThrows;

import org.junit.After;
import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
Expand Down Expand Up @@ -48,6 +48,17 @@ public void setUp() throws Exception {
modelId.compareAndSet(modelId.get(), prepareModel());
}

@After
@SneakyThrows
public void tearDown() {
super.tearDown();
/* this is required to minimize chance of model not being deployed due to open memory CB,
* this happens in case we leave model from previous test case. We use new model for every test, and old model
* can be undeployed and deleted to free resources after each test case execution.
*/
deleteModel(modelId.get());
}

/**
* Tests basic query:
* {
Expand Down Expand Up @@ -344,7 +355,8 @@ public void testFilterQuery() {
assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
@SneakyThrows
private void initializeIndexIfNotExist(String indexName) {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
prepareKnnIndex(
TEST_BASIC_INDEX_NAME,
Expand Down

0 comments on commit 12116a4

Please sign in to comment.