Skip to content

Commit

Permalink
broken deployment in integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
maxhniebergall committed Jan 3, 2024
1 parent f1b6192 commit 27a7f51
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import org.junit.ClassRule;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -64,6 +67,109 @@ static String mockServiceModelConfig() {
""";
}

// basic model from xpack/ml/integration/PyTorchModelIT.java
static final String BASE_64_ENCODED_MODEL =
"UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp"
+ "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA"
+ "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW"
+ "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh"
+ "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele"
+ "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k"
+ "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ"
+ "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+ "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq"
+ "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7"
+ "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3"
+ "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28"
+ "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw"
+ "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW"
+ "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0"
+ "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts"
+ "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs"
+ "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn"
+ "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU"
+ "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe"
+ "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE"
+ "AAJIEAAAAAA==";
static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
static {
RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
}

protected void createBasicModel(String modelId) throws IOException {
createPassThroughModel(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
}

private void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException {
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
String body = Strings.format("""
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel);
request.setJsonEntity(body);
System.out.println("putModelDefiniton:" + client().performRequest(request));
}

private void createPassThroughModel(String modelId) throws IOException {
createPassThroughModel(modelId, 0, 0);
}

private void createPassThroughModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
String metadata;
if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) {
metadata = org.elasticsearch.core.Strings.format("""
"metadata": {
"per_deployment_memory_bytes": %d,
"per_allocation_memory_bytes": %d
},""", perDeploymentMemoryBytes, perAllocationMemoryBytes);
} else {
metadata = "";
}
request.setJsonEntity(org.elasticsearch.core.Strings.format("""
{
"description": "simple model for testing",
"model_type": "pytorch",
%s
"inference_config": {
"pass_through": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}""", metadata));
System.out.println("createPassThroughModel:" + client().performRequest(request));
}

private void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
List<String> vocabularyWithPad = new ArrayList<>();
// vocabularyWithPad.add(BertTokenizer.PAD_TOKEN);
// vocabularyWithPad.add(BertTokenizer.UNKNOWN_TOKEN);
vocabularyWithPad.addAll(vocabulary);
String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));

Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
request.setJsonEntity(org.elasticsearch.core.Strings.format("""
{ "vocabulary": [%s] }
""", quotedWords));
System.out.println("putVocabulary:" + client().performRequest(request));
}

protected void startDeployment(String modelId) throws IOException {
Request request = new Request("POST", "_ml/trained_models/" + modelId + "/deployment/_start?wait_for=started&timeout=1m");
client().performRequest(request);
System.out.println("startDeployment:" + client().performRequest(request));

}

protected Response deleteModel(String modelId, boolean force) throws IOException {
Request request = new Request("DELETE", "/_ml/trained_models/" + modelId + "?force=" + force);
return client().performRequest(request);
}

protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
var request = new Request("PUT", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* This file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.inference.TaskType;
import org.junit.After;

import java.io.IOException;
import java.util.List;
Expand All @@ -19,6 +23,17 @@

public class InferenceCrudIT extends InferenceBaseRestTest {

@After
public void cleanup() throws Exception {
waitForPendingTasks(client());
}

public void testPutAndDeleteModel() throws IOException {
String modelId = "a_model_for_happy_case";
createBasicModel(modelId);
deleteModel(modelId, true);
}

@SuppressWarnings("unchecked")
public void testGet() throws IOException {
for (int i = 0; i < 5; i++) {
Expand Down Expand Up @@ -49,8 +64,8 @@ public void testGet() throws IOException {
}

public void testGetModelWithWrongTaskType() throws IOException {
putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
putModel("sparse_embedding_model_one", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model_one", TaskType.TEXT_EMBEDDING));
assertThat(
e.getMessage(),
containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]")
Expand All @@ -59,11 +74,26 @@ public void testGetModelWithWrongTaskType() throws IOException {

@SuppressWarnings("unchecked")
public void testGetModelWithAnyTaskType() throws IOException {
String modelId = "sparse_embedding_model";
String modelId = "sparse_embedding_model_two";
putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var singleModel = (List<Map<String, Object>>) getModels(modelId, TaskType.ANY).get("models");
System.out.println("MODEL" + singleModel);
assertEquals(modelId, singleModel.get(0).get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
}

public void testPutModelFailsWhenTrainedModelWithIdAlreadyExists() throws Exception {
String modelId = "duplicate_model_id";
createBasicModel(modelId);
Request request = new Request("GET", "_ml/trained_models/_all");
System.out.println("99598" + client().performRequest(request));
startDeployment(modelId);

var e = expectThrows(ResponseException.class, () -> putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING));
assertEquals("Trained machine learning model [duplicate_model_id] already exists", e.getMessage());

// clean up
deleteModel(modelId, true);
}

}

0 comments on commit 27a7f51

Please sign in to comment.