diff --git a/common/build.gradle b/common/build.gradle index 79077317a1..4d5cb95740 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -34,6 +34,7 @@ dependencies { exclude group: 'com.google.j2objc', module: 'j2objc-annotations' exclude group: 'com.google.guava', module: 'listenablefuture' } + compileOnly 'com.jayway.jsonpath:json-path:2.9.0' } lombok { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index b9c69cd194..5bd3e3ecae 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -36,6 +36,8 @@ import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import com.jayway.jsonpath.JsonPath; + @Log4j2 public class StringUtils { @@ -56,6 +58,7 @@ public class StringUtils { static { gson = new Gson(); } + public static final String TO_STRING_FUNCTION_NAME = ".toString()"; public static boolean isValidJsonString(String Json) { try { @@ -239,4 +242,89 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea return errorMessage + " Model ID: " + modelId; } } + + public static String obtainFieldNameFromJsonPath(String jsonPath) { + String[] parts = jsonPath.split("\\."); + + // Get the last part which is the field name + return parts[parts.length - 1]; + } + + public static String getJsonPath(String jsonPathWithSource) { + // Find the index of the first occurrence of "$." + int startIndex = jsonPathWithSource.indexOf("$."); + + // Extract the substring from the startIndex to the end of the input string + return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource; + } + + /** + * Checks if the given input string matches the JSONPath format. + * + *

The JSONPath format is a way to navigate and extract data from JSON documents. + * It uses a syntax similar to XPath for XML documents. This method attempts to compile + * the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath} + * library. If the compilation succeeds, it means the input string is a valid JSONPath + * expression. + * + * @param input the input string to be checked for JSONPath format validity + * @return true if the input string is a valid JSONPath expression, false otherwise + */ + public static boolean isValidJSONPath(String input) { + if (input == null || input.isBlank()) { + return false; + } + try { + JsonPath.compile(input); // This will throw an exception if the path is invalid + return true; + } catch (Exception e) { + return false; + } + } + + + /** + * Collects the prefixes of the toString() method calls present in the values of the given map. + * + * @param map A map containing key-value pairs where the values may contain toString() method calls. + * @return A list of prefixes for the toString() method calls found in the map values. + */ + public static List collectToStringPrefixes(Map map) { + List prefixes = new ArrayList<>(); + for (String key : map.keySet()) { + String value = map.get(key); + if (value != null) { + Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}"); + Matcher matcher = pattern.matcher(value); + while (matcher.find()) { + String prefix = matcher.group(1); + prefixes.add(prefix); + } + } + } + return prefixes; + } + + /** + * Parses the given parameters map and processes the values containing toString() method calls. + * + * @param parameters A map containing key-value pairs where the values may contain toString() method calls. + * @return A new map with the processed values for the toString() method calls. + */ + public static Map parseParameters(Map parameters) { + if (parameters != null) { + List toStringParametersPrefixes = collectToStringPrefixes(parameters); + + if (!toStringParametersPrefixes.isEmpty()) { + for (String prefix : toStringParametersPrefixes) { + String value = parameters.get(prefix); + if (value != null) { + parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value)); + } + } + } + } + return parameters; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 7a475a2b87..d7f2ae44b7 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -6,7 +6,6 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; -import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -31,9 +30,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.script.TemplateScript; -import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; /** * MLInferenceIngestProcessor requires a modelId string to call model inferences @@ -57,11 +54,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; - private Configuration suppressExceptionConfiguration = Configuration - .builder() - .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST) - .build(); - protected MLInferenceIngestProcessor( String modelId, List> inputMaps, @@ -243,24 +235,29 @@ private void getMappedModelInputFromDocuments( Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class); String documentFieldValueAsString = toString(documentFieldValue); updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters); + return; } - // else when cannot find field path in document, try check for nested array using json path - else { - if (documentFieldName.contains(DOT_SYMBOL)) { - - Map sourceObject = ingestDocument.getSourceAndMetadata(); - ArrayList fieldValueList = JsonPath - .using(suppressExceptionConfiguration) - .parse(sourceObject) - .read(documentFieldName); - if (!fieldValueList.isEmpty()) { - updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters); - } else if (!ignoreMissing) { - throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + // If the standard dot path fails, try to check for a nested array using JSON path + if (StringUtils.isValidJSONPath(documentFieldName)) { + Map sourceObject = ingestDocument.getSourceAndMetadata(); + Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName); + + if (fieldValue != null) { + if (fieldValue instanceof List) { + List fieldValueList = (List) fieldValue; + if (!fieldValueList.isEmpty()) { + updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters); + } else if (!ignoreMissing) { + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); + } + } else { + updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters); } } else if (!ignoreMissing) { - throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); } + } else { + throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName); } } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index 54d4ef220b..56bf0dc11d 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; import java.nio.ByteBuffer; @@ -26,11 +27,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ingest.IngestDocument; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.script.ScriptService; @@ -109,7 +112,7 @@ public void testExecute_Exception() throws Exception { */ public void testExecute_nestedObjectStringDocumentSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.chunk"); + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk"); MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); @@ -122,13 +125,26 @@ public void testExecute_nestedObjectStringDocumentSuccess() { return null; }).when(client).execute(any(), any(), any()); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); processor.execute(nestedObjectIngestDocument, handler); + // match output documents Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, ImmutableMap.of("response", Arrays.asList(1, 2, 3))); IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(nestedObjectIngestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"inputs\":\"[{\\\"text\\\":[{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is first\\\"},{\\\"chapter\\\":\\\"first chapter\\\",\\\"context\\\":\\\"this is second\\\"}]},{\\\"text\\\":[{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is third\\\"},{\\\"chapter\\\":\\\"second chapter\\\",\\\"context\\\":\\\"this is fourth\\\"}]}]\"}" + ); + } /** @@ -149,6 +165,23 @@ public void testExecute_nestedObjectMapDocumentSuccess() { return null; }).when(client).execute(any(), any(), any()); + /** + * Preview of sourceAndMetadata + * { + * "chunks": [ + * { + * "chunk": { + * "text": "this is first" + * } + * }, + * { + * "chunk": { + * "text": "this is second" + * } + * } + * ] + * } + */ ArrayList childDocuments = new ArrayList<>(); Map childDocument1Text = new HashMap<>(); childDocument1Text.put("text", "this is first"); @@ -166,6 +199,8 @@ public void testExecute_nestedObjectMapDocumentSuccess() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("chunks", childDocuments); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); @@ -196,6 +231,13 @@ public void testExecute_nestedObjectMapDocumentSuccess() { IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(nestedObjectIngestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\"]\"}"); } public void testExecute_jsonPathWithMissingLeaves() { @@ -220,7 +262,7 @@ public void testExecute_jsonPathWithMissingLeaves() { /** * test nested object document with array of Map, - * the value Object is a also a nested object, + * the value Object is also a nested object, */ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); @@ -294,6 +336,7 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); @@ -302,6 +345,17 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.1.embedding", Object.class), Arrays.asList(2)); assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.0.embedding", Object.class), Arrays.asList(3)); assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.1.embedding", Object.class), Arrays.asList(4)); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals( + toJson(inputDataSet.getParameters()), + "{\"inputs\":\"[\\\"this is first\\\",\\\"this is second\\\",\\\"this is third\\\",\\\"this is fourth\\\"]\"}" + ); + } public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingLeaveSuccess() { @@ -652,6 +706,8 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { return null; }).when(client).execute(any(), any(), any()); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + processor.execute(ingestDocument, handler); Map sourceAndMetadata = new HashMap<>(); @@ -659,8 +715,64 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { sourceAndMetadata.put("key2", "value2"); sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876")); IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + // match output verify(handler).accept(eq(ingestDocument1), isNull()); assertEquals(ingestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"key1\":\"value1\",\"key2\":\"value2\"}"); + } + + public void testExecute_InputMapAndOutputMapSuccess() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("classification", "response"); + outputMap.add(output); + + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("inputs", "key1"); + inputMap.add(input); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put("classification", ImmutableMap.of("language", "en", "score", "0.9876")); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + // match output + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + + // match model input + verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argCaptor.capture(), any()); + MLPredictionTaskRequest req = argCaptor.getValue(); + MLInput mlInput = req.getMlInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + assertEquals(toJson(inputDataSet.getParameters()), "{\"inputs\":\"value1\"}"); } public void testExecute_getModelOutputFieldWithDotPathSuccess() { @@ -844,7 +956,7 @@ public void testExecute_documentNotExistedFieldNameException() { processor.execute(ingestDocument, handler); verify(handler) - .accept(eq(null), argThat(exception -> exception.getMessage().equals("cannot find field name defined from input map: key99"))); + .accept(eq(null), argThat(exception -> exception.getMessage().equals("Cannot find field name defined from input map: key99"))); } public void testExecute_nestedDocumentNotExistedFieldNameException() { @@ -859,7 +971,7 @@ public void testExecute_nestedDocumentNotExistedFieldNameException() { argThat( exception -> exception .getMessage() - .equals("cannot find field name defined from input map: chunks.*.chunk.text.*.context1") + .equals("Cannot find field name defined from input map: chunks.*.chunk.text.*.context1") ) ); } @@ -1044,6 +1156,18 @@ public void testParseGetDataInTensor_BooleanDataType() { } private static Map getNestedObjectWithAnotherNestedObjectSource() { + /** + * {chunks=[ + * {chunk={text=[ + * {context=this is first, chapter=first chapter}, + * {context=this is second, chapter=first chapter} + * ]}}, + * {chunk={text=[ + * {context=this is third, chapter=second chapter}, + * {context=this is fourth, chapter=second chapter} + * ]}} + * ]} + */ ArrayList childDocuments = new ArrayList<>(); Map childDocument1Text = new HashMap<>(); @@ -1070,7 +1194,7 @@ private static Map getNestedObjectWithAnotherNestedObjectSource( grandChildDocument3Text.put("chapter", "second chapter"); Map grandChildDocument4Text = new HashMap<>(); grandChildDocument4Text.put("context", "this is fourth"); - grandChildDocument4Text.put("chapter", "first chapter"); + grandChildDocument4Text.put("chapter", "second chapter"); grandChildDocuments2.add(grandChildDocument3Text); grandChildDocuments2.add(grandChildDocument4Text);