From 56fc80a6ab74a520c5013b7c960c2ad37ff1df44 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 9 Dec 2024 13:24:09 +0800 Subject: [PATCH] Support different embedding types of model response Signed-off-by: zane-neo --- .../neuralsearch/common/VectorUtil.java | 4 +-- .../ml/MLCommonsClientAccessor.java | 29 ++++++++--------- .../TextImageEmbeddingProcessor.java | 4 +-- .../neuralsearch/common/VectorUtilTests.java | 6 ++-- .../ml/MLCommonsClientAccessorTests.java | 31 ++++++++++--------- .../query/NeuralQueryBuilderTests.java | 8 ++--- .../neuralsearch/BaseNeuralSearchIT.java | 2 +- 7 files changed, 44 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java index bfbb2e6d9..5e5f5cd33 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java @@ -21,10 +21,10 @@ public class VectorUtil { * @param vectorAsList {@link List} of {@link Float}'s representing the vector * @return array of floats produced from input list */ - public static float[] vectorAsListToArray(List vectorAsList) { + public static float[] vectorAsListToArray(List vectorAsList) { float[] vector = new float[vectorAsList.size()]; for (int i = 0; i < vectorAsList.size(); i++) { - vector[i] = vectorAsList.get(i); + vector[i] = vectorAsList.get(i).floatValue(); } return vector; } diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index f9ddf73a9..39ec5e243 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -52,7 +52,7 @@ public class MLCommonsClientAccessor { public void inferenceSentence( @NonNull final String modelId, @NonNull final String inputText, - @NonNull final ActionListener> listener + @NonNull final ActionListener> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { if (response.size() != 1) { @@ -82,7 +82,7 @@ public void inferenceSentence( public void inferenceSentences( @NonNull final String modelId, @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final ActionListener>> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); } @@ -103,7 +103,7 @@ public void inferenceSentences( @NonNull final List targetResponseFilters, @NonNull final String modelId, @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final ActionListener>> listener ) { retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } @@ -128,7 +128,7 @@ public void inferenceSentencesWithMapResult( public void inferenceSentences( @NonNull final String modelId, @NonNull final Map inputObjects, - @NonNull final ActionListener> listener + @NonNull final ActionListener> listener ) { retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } @@ -177,11 +177,11 @@ private void retryableInferenceSentencesWithVectorResult( final String modelId, final List inputText, final int retryTime, - final ActionListener>> listener + final ActionListener>> listener ) { MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> vector = buildVectorFromResponse(mlOutput); + final List> vector = buildVectorFromResponse(mlOutput); listener.onResponse(vector); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { @@ -202,7 +202,8 @@ private void retryableInferenceSimilarityWithVectorResult( ) { MLInput mlInput = createMLTextPairsInput(queryText, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); + final List> tensors = buildVectorFromResponse(mlOutput); + final List scores = tensors.stream().map(v -> v.get(0)).collect(Collectors.toList()); listener.onResponse(scores); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { @@ -224,14 +225,14 @@ private MLInput createMLTextPairsInput(final String query, final List in return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); } - private List> buildVectorFromResponse(MLOutput mlOutput) { - final List> vector = new ArrayList<>(); + private List> buildVectorFromResponse(MLOutput mlOutput) { + final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); for (final ModelTensors tensors : tensorOutputList) { final List tensorsList = tensors.getMlModelTensors(); for (final ModelTensor tensor : tensorsList) { - vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList())); + vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList())); } } return vector; @@ -255,8 +256,8 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return resultMaps; } - private List buildSingleVectorFromResponse(final MLOutput mlOutput) { - final List> vector = buildVectorFromResponse(mlOutput); + private List buildSingleVectorFromResponse(final MLOutput mlOutput) { + final List> vector = buildVectorFromResponse(mlOutput); return vector.isEmpty() ? new ArrayList<>() : vector.get(0); } @@ -265,11 +266,11 @@ private void retryableInferenceSentencesWithSingleVectorResult( final String modelId, final Map inputObjects, final int retryTime, - final ActionListener> listener + final ActionListener> listener ) { MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List vector = buildSingleVectorFromResponse(mlOutput); + final List vector = buildSingleVectorFromResponse(mlOutput); log.debug("Inference Response for input sentence is : {} ", vector); listener.onResponse(vector); }, e -> { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index e808869f9..514216710 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -124,7 +124,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer vectors) { + private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); @@ -164,7 +164,7 @@ Map buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildTextEmbeddingResult(final String knnKey, List modelTensorList) { + Map buildTextEmbeddingResult(final String knnKey, List modelTensorList) { Map result = new LinkedHashMap<>(); result.put(knnKey, modelTensorList); return result; diff --git a/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java index a06e8f84d..4ebb7858f 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java @@ -12,15 +12,15 @@ public class VectorUtilTests extends OpenSearchTestCase { public void testVectorAsListToArray() { - List vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f); + List vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f); float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements); assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length); for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) { - assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f); + assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f); } - List vectorAsList_withNoElements = Collections.emptyList(); + List vectorAsList_withNoElements = Collections.emptyList(); float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements); assertEquals(0, vectorAsArray_withNoElements.length); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3749e63dc..552ac1249 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -36,10 +36,13 @@ public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @Mock - private ActionListener>> resultListener; + private ActionListener>> resultListener; @Mock - private ActionListener> singleSentenceResultListener; + private ActionListener> singleSentenceResultListener; + + @Mock + private ActionListener> similarityResultListener; @Mock private MachineLearningNodeClient client; @@ -53,7 +56,7 @@ public void setup() { } public void testInferenceSentence_whenValidInput_thenSuccess() { - final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); @@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { } public void testInferenceSentences_whenValidInputThenSuccess() { - final List> vectorList = new ArrayList<>(); + final List> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); @@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() { } public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { - final List> vectorList = new ArrayList<>(); + final List> vectorList = new ArrayList<>(); vectorList.add(Collections.emptyList()); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); @@ -278,7 +281,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa } public void testInferenceMultimodal_whenValidInput_thenSuccess() { - final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); @@ -337,13 +340,13 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { TestCommonConstants.MODEL_ID, "is it sunny", List.of("it is sunny today", "roses are red"), - singleSentenceResultListener + similarityResultListener ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(singleSentenceResultListener).onResponse(vector); - Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + Mockito.verify(similarityResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(similarityResultListener); } public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { @@ -358,13 +361,13 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { TestCommonConstants.MODEL_ID, "is it sunny", List.of("it is sunny today", "roses are red"), - singleSentenceResultListener + similarityResultListener ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(singleSentenceResultListener).onFailure(exception); - Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + Mockito.verify(similarityResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(similarityResultListener); } public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() { @@ -382,12 +385,12 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi TestCommonConstants.MODEL_ID, "is it sunny", List.of("it is sunny today", "roses are red"), - singleSentenceResultListener + similarityResultListener ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException); } private ModelTensorOutput createModelTensorOutput(final Float[] output) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 6d8e810f3..dbf05144b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -646,10 +646,10 @@ public void testHashAndEquals() { @SneakyThrows public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K); - List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(2); listener.onResponse(expectedVector); return null; }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); @@ -682,10 +682,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K); - List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(2); listener.onResponse(expectedVector); return null; }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index afc545447..8571189bc 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -268,7 +268,7 @@ protected float[] runInference(final String modelId, final String queryText) { List output = (List) result.get("output"); assertEquals(1, output.size()); Map map = (Map) output.get(0); - List data = ((List) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList()); + List data = ((List) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList()); return vectorAsListToArray(data); }