From b3b0e871f8c47a5fc17f52afe5a846550c9e1f5b Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Tue, 16 Jul 2024 18:44:12 +0800 Subject: [PATCH 1/4] Fix MLModelTool returns null if the response of LLM is a pure json object Signed-off-by: Heng Qian --- .../ml/engine/tools/MLModelTool.java | 9 ++++++++- .../ml/engine/tools/MLModelToolTests.java | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 8a33edca7d..1033b7df7a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import lombok.Getter; @@ -54,6 +55,7 @@ public class MLModelTool implements Tool { private Parser inputParser; @Setter @Getter + @VisibleForTesting private Parser outputParser; @Setter @Getter @@ -66,7 +68,12 @@ public MLModelTool(Client client, String modelId, String responseField) { outputParser = o -> { List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField); + Map dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap(); + if (dataAsMap.size() == 1 && dataAsMap.containsKey(responseField)) { + return dataAsMap.get(responseField); + } else { + return StringUtils.toJson(dataAsMap); + } }; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index f6b54b56be..0fcf2d2ab5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -170,6 +170,26 @@ public void testOutputParserLambda() { assertEquals("testResponse", result); } + @Test + public void testOutputParserWithJsonResponse() { + Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser(); + String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + + // Create a mock ModelTensors with json object + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + assertEquals(expectedJson, result); + + // Create a mock ModelTensors with response string + modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build(); + modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + assertEquals(expectedJson, result); + } + @Test public void testRunWithError() { // Mocking the client.execute to simulate an error From c28de3b444f610f3b4ab3ec008169af5de7b1e29 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Tue, 16 Jul 2024 19:02:18 +0800 Subject: [PATCH 2/4] Fix UT failure Signed-off-by: Heng Qian --- .../main/java/org/opensearch/ml/engine/tools/MLModelTool.java | 3 ++- .../java/org/opensearch/ml/engine/tools/MLModelToolTests.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 1033b7df7a..12db49891e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -69,7 +69,8 @@ public MLModelTool(Client client, String modelId, String responseField) { outputParser = o -> { List mlModelOutputs = (List) o; Map dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap(); - if (dataAsMap.size() == 1 && dataAsMap.containsKey(responseField)) { + // Return the response field if it exists, otherwise return the whole response a json string. + if (dataAsMap.containsKey(responseField)) { return dataAsMap.get(responseField); } else { return StringUtils.toJson(dataAsMap); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index 0fcf2d2ab5..3aa76cd554 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw tool.run(null, listener); future.join(); - assertEquals(null, future.get()); + assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get()); } @Test From 39d1bcfeb04c4c8573f741349551675b10c460a0 Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Wed, 17 Jul 2024 17:49:26 +0800 Subject: [PATCH 3/4] Avoid NPE Signed-off-by: Heng Qian --- .../ml/engine/tools/MLModelTool.java | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 12db49891e..a7b209ac37 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; @@ -67,13 +68,18 @@ public MLModelTool(Client client, String modelId, String responseField) { this.responseField = responseField; outputParser = o -> { - List mlModelOutputs = (List) o; - Map dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap(); - // Return the response field if it exists, otherwise return the whole response a json string. - if (dataAsMap.containsKey(responseField)) { - return dataAsMap.get(responseField); - } else { - return StringUtils.toJson(dataAsMap); + try { + List mlModelOutputs = (List) o; + Map dataAsMap = mlModelOutputs.getFirst().getMlModelTensors().getFirst() + .getDataAsMap(); + // Return the response field if it exists, otherwise return the whole response as json string. + if (dataAsMap.containsKey(responseField)) { + return dataAsMap.get(responseField); + } else { + return StringUtils.toJson(dataAsMap); + } + } catch (Exception e) { + throw new IllegalStateException("LLM returns wrong or empty tensors", e); } }; } From 76a2af94ae92ffe1bb66b78f335a71a5cb9feade Mon Sep 17 00:00:00 2001 From: Heng Qian Date: Wed, 17 Jul 2024 17:51:38 +0800 Subject: [PATCH 4/4] spotlessApply Signed-off-by: Heng Qian --- .../main/java/org/opensearch/ml/engine/tools/MLModelTool.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index a7b209ac37..e943ea0ebc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -8,7 +8,6 @@ import java.util.List; import java.util.Map; -import java.util.NoSuchElementException; import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; @@ -70,8 +69,7 @@ public MLModelTool(Client client, String modelId, String responseField) { outputParser = o -> { try { List mlModelOutputs = (List) o; - Map dataAsMap = mlModelOutputs.getFirst().getMlModelTensors().getFirst() - .getDataAsMap(); + Map dataAsMap = mlModelOutputs.getFirst().getMlModelTensors().getFirst().getDataAsMap(); // Return the response field if it exists, otherwise return the whole response as json string. if (dataAsMap.containsKey(responseField)) { return dataAsMap.get(responseField);