diff --git a/common/build.gradle b/common/build.gradle index 79077317a1..4d5cb95740 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -34,6 +34,7 @@ dependencies { exclude group: 'com.google.j2objc', module: 'j2objc-annotations' exclude group: 'com.google.guava', module: 'listenablefuture' } + compileOnly 'com.jayway.jsonpath:json-path:2.9.0' } lombok { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 9fab197a8c..7d9a890e3a 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -36,6 +36,8 @@ import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import com.jayway.jsonpath.JsonPath; + @Log4j2 public class StringUtils { @@ -56,6 +58,7 @@ public class StringUtils { static { gson = new Gson(); } + public static final String TO_STRING_FUNCTION_NAME = ".toString()"; public static boolean isValidJsonString(String Json) { try { @@ -239,4 +242,89 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea return errorMessage + " Model ID: " + modelId; } } + + public static String obtainFieldNameFromJsonPath(String jsonPath) { + String[] parts = jsonPath.split("\\."); + + // Get the last part which is the field name + return parts[parts.length - 1]; + } + + public static String getJsonPath(String jsonPathWithSource) { + // Find the index of the first occurrence of "$." + int startIndex = jsonPathWithSource.indexOf("$."); + + // Extract the substring from the startIndex to the end of the input string + return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource; + } + + /** + * Checks if the given input string matches the JSONPath format. + * + *

The JSONPath format is a way to navigate and extract data from JSON documents. + * It uses a syntax similar to XPath for XML documents. This method attempts to compile + * the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath} + * library. If the compilation succeeds, it means the input string is a valid JSONPath + * expression. + * + * @param input the input string to be checked for JSONPath format validity + * @return true if the input string is a valid JSONPath expression, false otherwise + */ + public static boolean isValidJSONPath(String input) { + if (input == null || input.isBlank()) { + return false; + } + try { + JsonPath.compile(input); // This will throw an exception if the path is invalid + return true; + } catch (Exception e) { + return false; + } + } + + + /** + * Collects the prefixes of the toString() method calls present in the values of the given map. + * + * @param map A map containing key-value pairs where the values may contain toString() method calls. + * @return A list of prefixes for the toString() method calls found in the map values. + */ + public static List collectToStringPrefixes(Map map) { + List prefixes = new ArrayList<>(); + for (String key : map.keySet()) { + String value = map.get(key); + if (value != null) { + Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}"); + Matcher matcher = pattern.matcher(value); + while (matcher.find()) { + String prefix = matcher.group(1); + prefixes.add(prefix); + } + } + } + return prefixes; + } + + /** + * Parses the given parameters map and processes the values containing toString() method calls. + * + * @param parameters A map containing key-value pairs where the values may contain toString() method calls. + * @return A new map with the processed values for the toString() method calls. + */ + public static Map parseParameters(Map parameters) { + if (parameters != null) { + List toStringParametersPrefixes = collectToStringPrefixes(parameters); + + if (!toStringParametersPrefixes.isEmpty()) { + for (String prefix : toStringParametersPrefixes) { + String value = parameters.get(prefix); + if (value != null) { + parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value)); + } + } + } + } + return parameters; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 48e0464fbc..5424746d1a 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -7,7 +7,6 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; import java.io.IOException; -import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; @@ -37,9 +36,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.script.TemplateScript; -import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; /** * MLInferenceIngestProcessor requires a modelId string to call model inferences @@ -75,11 +72,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }"; private final NamedXContentRegistry xContentRegistry; - private Configuration suppressExceptionConfiguration = Configuration - .builder() - .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST) - .build(); - protected MLInferenceIngestProcessor( String modelId, List> inputMaps, @@ -320,24 +312,29 @@ private void getMappedModelInputFromDocuments( Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class); String documentFieldValueAsString = toString(documentFieldValue); updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters); + return; } - // else when cannot find field path in document, try check for nested array using json path - else { - if (documentFieldName.contains(DOT_SYMBOL)) { - - Map sourceObject = ingestDocument.getSourceAndMetadata(); - ArrayList fieldValueList = JsonPath - .using(suppressExceptionConfiguration) - .parse(sourceObject) - .read(documentFieldName); - if (!fieldValueList.isEmpty()) { - updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters); - } else if (!ignoreMissing) { - throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + // If the standard dot path fails, try to check for a nested array using JSON path + if (StringUtils.isValidJSONPath(documentFieldName)) { + Map sourceObject = ingestDocument.getSourceAndMetadata(); + Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName); + + if (fieldValue != null) { + if (fieldValue instanceof List) { + List fieldValueList = (List) fieldValue; + if (!fieldValueList.isEmpty()) { + updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters); + } else if (!ignoreMissing) { + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); + } + } else { + updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters); } } else if (!ignoreMissing) { - throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); } + } else { + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); } } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index 203392eb75..3ff5d957f3 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; import java.io.IOException; @@ -31,11 +32,13 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.IngestDocument; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.script.ScriptService; @@ -164,13 +167,26 @@ public void testExecute_nestedObjectStringDocumentSuccess() { return null; }).when(client).execute(any(), any(), any()); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); processor.execute(nestedObjectIngestDocument, handler); + // match output documents Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, ImmutableMap.of("response", Arrays.asList(1, 2, 3))); IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(nestedObjectIngestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"inputs\":\"[{\\\"text\\\":[{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is first\\\"},{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is second\\\"}]},{\\\"text\\\":[{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is third\\\"},{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is fourth\\\"}]}]\"}" + ); + } /** @@ -202,6 +218,23 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { return null; }).when(client).execute(any(), any(), any()); + /** + * Preview of sourceAndMetadata + * { + * "chunks": [ + * { + * "chunk": { + * "text": "this is first" + * } + * }, + * { + * "chunk": { + * "text": "this is second" + * } + * } + * ] + * } + */ ArrayList childDocuments = new ArrayList<>(); Map childDocument1Text = new HashMap<>(); childDocument1Text.put("text", "this is first"); @@ -219,6 +252,8 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("chunks", childDocuments); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); @@ -250,6 +285,13 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(nestedObjectIngestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\"]\"}"); } public void testExecute_jsonPathWithMissingLeaves() { @@ -274,7 +316,7 @@ public void testExecute_jsonPathWithMissingLeaves() { /** * test nested object document with array of Map, - * the value Object is a also a nested object, + * the value Object is also a nested object, */ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() throws IOException { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); @@ -371,6 +413,7 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); @@ -379,6 +422,17 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.1.embedding", Object.class), Arrays.asList(2)); assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.0.embedding", Object.class), Arrays.asList(3)); assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.1.embedding", Object.class), Arrays.asList(4)); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\",\\\"this is third\\\",\\\"this is fourth\\\"]\"}" + ); + } public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingLeaveSuccess() { @@ -955,6 +1009,8 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { return null; }).when(client).execute(any(), any(), any()); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + processor.execute(ingestDocument, handler); Map sourceAndMetadata = new HashMap<>(); @@ -962,8 +1018,75 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { sourceAndMetadata.put("key2", "value2"); sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876")); IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + // match output verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(ingestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"key1\":\"value1\",\"key2\":\"value2\"}"); + } + + public void testExecute_InputMapAndOutputMapSuccess() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("classification", "response"); + outputMap.add(output); + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("inputs", "key1"); + inputMap.add(input); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876")); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + // match output + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"value1\"}"); } public void testExecute_getModelOutputFieldWithDotPathSuccess() { @@ -1209,7 +1332,7 @@ public void testExecute_documentNotExistedFieldNameException() { processor.execute(ingestDocument, handler); verify(handler) - .accept(eq(null), argThat(exception -> exception.getMessage().equals("cannot find field name defined from input map: key99"))); + .accept(eq(null), argThat(exception -> exception.getMessage().equals("Cannot find field name defined from input map: key99"))); } public void testExecute_nestedDocumentNotExistedFieldNameException() { @@ -1235,7 +1358,7 @@ public void testExecute_nestedDocumentNotExistedFieldNameException() { argThat( exception -> exception .getMessage() - .equals("cannot find field name defined from input map: chunks.*.chunk.text.*.context1") + .equals("Cannot find field name defined from input map: chunks.*.chunk.text.*.context1") ) ); } @@ -1613,7 +1736,40 @@ public void testExecute_localModelSuccess() { updatedBooks.add(updatedBook2); sourceAndMetadata.put("books", updatedBooks); - IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + // match meta data + Map expectedIngestMetadata = new HashMap<>(); + Map valueMap = new HashMap<>(); + Map titleEmbeddingMap = new HashMap<>(); + List> inferenceResultsList = new ArrayList<>(); + + Map expectedOutputMap = new HashMap<>(); + List> outputList = new ArrayList<>(); + Map dataMap = new HashMap<>(); + dataMap.put("data", Arrays.asList(1.0, 2.0, 3.0, 4.0)); + dataMap.put("name", "sentence_embedding"); + + Map inferenceResultMap = new HashMap<>(); + List> outputListInner = new ArrayList<>(); + outputListInner.add(dataMap); + inferenceResultMap.put("output", outputListInner); + + Map dataAsMapMap = new HashMap<>(); + List> inferenceResultsListInner = new ArrayList<>(); + inferenceResultsListInner.add(inferenceResultMap); + dataAsMapMap.put("inference_results", inferenceResultsListInner); + + Map expectedDataAsMap = new HashMap<>(); + expectedDataAsMap.put("dataAsMap", dataAsMapMap); + outputList.add(expectedDataAsMap); + expectedOutputMap.put("output", outputList); + inferenceResultsList.add(expectedOutputMap); + + titleEmbeddingMap.put("inference_results", inferenceResultsList); + valueMap.put("title_embedding", titleEmbeddingMap); + expectedIngestMetadata.put("_value", valueMap); + + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, expectedIngestMetadata); + System.out.println(ingestDocument1); verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(nestedObjectIngestDocument, ingestDocument1); } @@ -1810,6 +1966,18 @@ public void testWriteNewDotPathForNestedObject() { } private static Map getNestedObjectWithAnotherNestedObjectSource() { + /** + * {chunks=[ + * {chunk={text=[ + * {context=this is first, chapter=first chapter}, + * {context=this is second, chapter=first chapter} + * ]}}, + * {chunk={text=[ + * {context=this is third, chapter=second chapter}, + * {context=this is fourth, chapter=second chapter} + * ]}} + * ]} + */ ArrayList childDocuments = new ArrayList<>(); Map childDocument1Text = new HashMap<>(); @@ -1836,7 +2004,7 @@ private static Map getNestedObjectWithAnotherNestedObjectSource( grandChildDocument3Text.put("chapter", "second chapter"); Map grandChildDocument4Text = new HashMap<>(); grandChildDocument4Text.put("context", "this is fourth"); - grandChildDocument4Text.put("chapter", "first chapter"); + grandChildDocument4Text.put("chapter", "second chapter"); grandChildDocuments2.add(grandChildDocument3Text); grandChildDocuments2.add(grandChildDocument4Text);