Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.14] [Backport 2.x] Fix MLModelTool returns null if the response of LLM is a pure json object #2684

Merged
merged 1 commit into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
Expand Down Expand Up @@ -54,6 +55,7 @@ public class MLModelTool implements Tool {
private Parser inputParser;
@Setter
@Getter
@VisibleForTesting
private Parser outputParser;
@Setter
@Getter
Expand All @@ -65,8 +67,18 @@ public MLModelTool(Client client, String modelId, String responseField) {
this.responseField = responseField;

outputParser = o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField);
try {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
Map<String, ?> dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap();
// Return the response field if it exists, otherwise return the whole response as json string.
if (dataAsMap.containsKey(responseField)) {
return dataAsMap.get(responseField);
} else {
return StringUtils.toJson(dataAsMap);
}
} catch (Exception e) {
throw new IllegalStateException("LLM returns wrong or empty tensors", e);
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw
tool.run(null, listener);

future.join();
assertEquals(null, future.get());
assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get());
}

@Test
Expand Down Expand Up @@ -170,6 +170,26 @@ public void testOutputParserLambda() {
assertEquals("testResponse", result);
}

@Test
public void testOutputParserWithJsonResponse() {
Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser();
String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}";

// Create a mock ModelTensors with json object
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
assertEquals(expectedJson, result);

// Create a mock ModelTensors with response string
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build();
modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
assertEquals(expectedJson, result);
}

@Test
public void testRunWithError() {
// Mocking the client.execute to simulate an error
Expand Down
Loading