diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
index 1578e03608e82..01ef659332939 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
@@ -21,8 +21,11 @@
 import org.junit.ClassRule;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Base64;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
@@ -64,6 +67,109 @@ static String mockServiceModelConfig() {
             """;
     }
 
+    // basic model from xpack/ml/integration/PyTorchModelIT.java
+    static final String BASE_64_ENCODED_MODEL =
+        "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp"
+            + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA"
+            + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW"
+            + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpixxh"
+            + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele"
+            + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k"
+            + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ"
+            + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+            + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq"
+            + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7"
+            + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3"
+            + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28"
+            + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw"
+            + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW"
+            + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0"
+            + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts"
+            + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs"
+            + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn"
+            + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU"
+            + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe"
+            + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE"
+            + "AAJIEAAAAAA==";
+    static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
+    static {
+        RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
+    }
+
+    protected void createBasicModel(String modelId) throws IOException {
+        createPassThroughModel(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
+    }
+
+    private void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
+        String body = Strings.format("""
+            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel);
+        request.setJsonEntity(body);
+        System.out.println("putModelDefinition:" + client().performRequest(request));
+    }
+
+    private void createPassThroughModel(String modelId) throws IOException {
+        createPassThroughModel(modelId, 0, 0);
+    }
+
+    private void createPassThroughModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
+        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+        String metadata;
+        if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) {
+            metadata = org.elasticsearch.core.Strings.format("""
+                "metadata": {
+                    "per_deployment_memory_bytes": %d,
+                    "per_allocation_memory_bytes": %d
+                },""", perDeploymentMemoryBytes, perAllocationMemoryBytes);
+        } else {
+            metadata = "";
+        }
+        request.setJsonEntity(org.elasticsearch.core.Strings.format("""
+            {
+                "description": "simple model for testing",
+                "model_type": "pytorch",
+                %s
+                "inference_config": {
+                    "pass_through": {
+                        "tokenization": {
+                            "bert": {
+                                "with_special_tokens": false
+                            }
+                        }
+                    }
+                }
+            }""", metadata));
+        System.out.println("createPassThroughModel:" + client().performRequest(request));
+    }
+
+    private void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
+        List<String> vocabularyWithPad = new ArrayList<>();
+        // vocabularyWithPad.add(BertTokenizer.PAD_TOKEN);
+        // vocabularyWithPad.add(BertTokenizer.UNKNOWN_TOKEN);
+        vocabularyWithPad.addAll(vocabulary);
+        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
+
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
+        request.setJsonEntity(org.elasticsearch.core.Strings.format("""
+            { "vocabulary": [%s] }
+            """, quotedWords));
+        System.out.println("putVocabulary:" + client().performRequest(request));
+    }
+
+    protected void startDeployment(String modelId) throws IOException {
+        Request request = new Request("POST", "_ml/trained_models/" + modelId + "/deployment/_start?wait_for=started&timeout=1m");
+        System.out.println("startDeployment:" + client().performRequest(request));
+    }
+
+    protected Response deleteModel(String modelId, boolean force) throws IOException {
+        Request request = new Request("DELETE", "/_ml/trained_models/" + modelId + "?force=" + force);
+        return client().performRequest(request);
+    }
+
     protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
         String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         var request = new Request("PUT", endpoint);
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
index 61278fcae6d94..57a842debeaa9 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
@@ -3,12 +3,16 @@
  * or more contributor license agreements. Licensed under the Elastic License
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
+ *
+ * This file was contributed to by a generative AI
  */
 
 package org.elasticsearch.xpack.inference;
 
+import org.elasticsearch.client.Request;
 import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.inference.TaskType;
+import org.junit.After;
 
 import java.io.IOException;
 import java.util.List;
@@ -19,6 +23,17 @@
 
 public class InferenceCrudIT extends InferenceBaseRestTest {
 
+    @After
+    public void cleanup() throws Exception {
+        waitForPendingTasks(client());
+    }
+
+    public void testPutAndDeleteModel() throws IOException {
+        String modelId = "a_model_for_happy_case";
+        createBasicModel(modelId);
+        deleteModel(modelId, true);
+    }
+
     @SuppressWarnings("unchecked")
     public void testGet() throws IOException {
         for (int i = 0; i < 5; i++) {
@@ -49,8 +64,8 @@ public void testGet() throws IOException {
     }
 
     public void testGetModelWithWrongTaskType() throws IOException {
-        putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
-        var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
+        putModel("sparse_embedding_model_one", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model_one", TaskType.TEXT_EMBEDDING));
         assertThat(
             e.getMessage(),
             containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]")
@@ -59,11 +74,26 @@ public void testGetModelWithWrongTaskType() throws IOException {
 
     @SuppressWarnings("unchecked")
     public void testGetModelWithAnyTaskType() throws IOException {
-        String modelId = "sparse_embedding_model";
+        String modelId = "sparse_embedding_model_two";
         putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var singleModel = (List<Map<String, Object>>) getModels(modelId, TaskType.ANY).get("models");
         System.out.println("MODEL" + singleModel);
         assertEquals(modelId, singleModel.get(0).get("model_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
     }
+
+    public void testPutModelFailsWhenTrainedModelWithIdAlreadyExists() throws Exception {
+        String modelId = "duplicate_model_id";
+        createBasicModel(modelId);
+        Request request = new Request("GET", "_ml/trained_models/_all");
+        System.out.println("99598" + client().performRequest(request));
+        startDeployment(modelId);
+
+        var e = expectThrows(ResponseException.class, () -> putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING));
+        assertThat(e.getMessage(), containsString("Trained machine learning model [duplicate_model_id] already exists"));
+
+        // clean up
+        deleteModel(modelId, true);
+    }
+
 }