-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix MLModelTool returns null if the response of LLM is a pure json object #2655
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; | ||
|
||
import lombok.Getter; | ||
|
@@ -54,6 +55,7 @@ public class MLModelTool implements Tool { | |
private Parser inputParser; | ||
@Setter | ||
@Getter | ||
@VisibleForTesting | ||
private Parser outputParser; | ||
@Setter | ||
@Getter | ||
|
@@ -66,7 +68,13 @@ public MLModelTool(Client client, String modelId, String responseField) { | |
|
||
outputParser = o -> { | ||
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o; | ||
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField); | ||
Map<String, ?> dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap(); | ||
// Return the response field if it exists, otherwise return the whole response a json string. | ||
if (dataAsMap.containsKey(responseField)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Original code doesn't have check on dataAsMap which might cause NPE in edge case, it would be good to add a check here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see many places that access dataAsMap directly without any index checking, so I follow that rule. But I can add array index checking. |
||
return dataAsMap.get(responseField); | ||
} else { | ||
return StringUtils.toJson(dataAsMap); | ||
} | ||
}; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make UT be able to access the private field
outputParser
and test whether its results are as expected.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But in the test you already declared:
Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser();
How are you usingoutputParser
from the MLModelTool class then?