diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 8a33edca7d..1bcf6c9ef0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -21,6 +21,7 @@ import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import lombok.Getter; @@ -54,6 +55,7 @@ public class MLModelTool implements Tool { private Parser inputParser; @Setter @Getter + @VisibleForTesting private Parser outputParser; @Setter @Getter @@ -65,8 +67,18 @@ public MLModelTool(Client client, String modelId, String responseField) { this.responseField = responseField; outputParser = o -> { - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField); + try { + List mlModelOutputs = (List) o; + Map dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap(); + // Return the response field if it exists, otherwise return the whole response as json string. + if (dataAsMap.containsKey(responseField)) { + return dataAsMap.get(responseField); + } else { + return StringUtils.toJson(dataAsMap); + } + } catch (Exception e) { + throw new IllegalStateException("LLM returns wrong or empty tensors", e); + } }; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index f6b54b56be..3aa76cd554 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw tool.run(null, listener); future.join(); - assertEquals(null, future.get()); + assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get()); } @Test @@ -170,6 +170,26 @@ public void testOutputParserLambda() { assertEquals("testResponse", result); } + @Test + public void testOutputParserWithJsonResponse() { + Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser(); + String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + + // Create a mock ModelTensors with json object + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + assertEquals(expectedJson, result); + + // Create a mock ModelTensors with response string + modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build(); + modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + assertEquals(expectedJson, result); + } + @Test public void testRunWithError() { // Mocking the client.execute to simulate an error