Fix MLModelTool returns null if the response of LLM is a pure json object #2655

Merged
merged 4 commits into from
Jul 18, 2024
Changes from 2 commits
@@ -21,6 +21,7 @@
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
+ import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
@@ -54,6 +55,7 @@ public class MLModelTool implements Tool {
private Parser inputParser;
@Setter
@Getter
+ @VisibleForTesting
Collaborator

Why do we need this?

Contributor Author

It lets the unit test access the private field outputParser and check whether its results are as expected.

Collaborator

But in the test you already declared: Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser(); How are you using outputParser from the MLModelTool class then?

private Parser outputParser;
@Setter
@Getter
@@ -66,7 +68,13 @@ public MLModelTool(Client client, String modelId, String responseField) {

outputParser = o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
- return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField);
+ Map<String, ?> dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap();
+ // Return the response field if it exists, otherwise return the whole response as a JSON string.
+ if (dataAsMap.containsKey(responseField)) {
Collaborator

The original code has no null check on dataAsMap, which might cause an NPE in edge cases; it would be good to add a check here.

Contributor Author

I see many places that access dataAsMap directly without any index checking, so I followed that convention. But I can add array index checking.

+ return dataAsMap.get(responseField);
+ } else {
+ return StringUtils.toJson(dataAsMap);
+ }
};
}
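
For readers following the review thread, here is a minimal, self-contained sketch of the fallback behavior together with the defensive checks the reviewer asked about. It is not the merged code: `OutputParserSketch`, `parse`, and `quote` are hypothetical names, plain `Map` values stand in for the `dataAsMap` of each tensor, and the local `toJson` is only a stand-in for `StringUtils.toJson` that handles flat maps of string values.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class OutputParserSketch {

    // Hypothetical stand-in for StringUtils.toJson: serializes a flat map,
    // treating every value as a string. Not a general JSON encoder.
    public static String toJson(Map<String, ?> map) {
        return map.entrySet().stream()
            .map(e -> quote(e.getKey()) + ":" + quote(String.valueOf(e.getValue())))
            .collect(Collectors.joining(",", "{", "}"));
    }

    // Escapes backslashes and double quotes only, then wraps the text in quotes.
    public static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }

    // dataMaps models the dataAsMap of each tensor (an assumption made to keep
    // the sketch free of OpenSearch types).
    public static Object parse(List<? extends Map<String, ?>> dataMaps, String responseField) {
        // Guard against a missing or empty output instead of risking an NPE
        // or an IndexOutOfBoundsException.
        if (dataMaps == null || dataMaps.isEmpty() || dataMaps.get(0) == null) {
            return null;
        }
        Map<String, ?> dataAsMap = dataMaps.get(0);
        // Return the response field if it exists, otherwise the whole map as JSON.
        if (dataAsMap.containsKey(responseField)) {
            return dataAsMap.get(responseField);
        }
        return toJson(dataAsMap);
    }

    public static void main(String[] args) {
        // LinkedHashMap keeps insertion order, so the JSON key order is stable.
        Map<String, Object> pureJson = new LinkedHashMap<>();
        pureJson.put("key1", "value1");
        pureJson.put("key2", "value2");

        System.out.println(parse(List.of(pureJson), "response"));
        System.out.println(parse(List.of(Map.of("response", "hello")), "response"));
        System.out.println(parse(List.of(), "response"));
    }
}
```

In this sketch a missing or empty output yields null; whether the real tool should return null or report an error in that case is a design choice the thread leaves open.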

@@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw
tool.run(null, listener);

future.join();
- assertEquals(null, future.get());
+ assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get());
}

@Test
@@ -170,6 +170,26 @@ public void testOutputParserLambda() {
assertEquals("testResponse", result);
}

+ @Test
+ public void testOutputParserWithJsonResponse() {
+ Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser();
+ String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
+
+ // Create a mock ModelTensors with json object
+ ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build();
+ ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+ Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
+ assertEquals(expectedJson, result);
+
+ // Create a mock ModelTensors with response string
+ modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build();
+ modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+ mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+ result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
+ assertEquals(expectedJson, result);
+ }

@Test
public void testRunWithError() {
// Mocking the client.execute to simulate an error