Skip to content

Commit

Permalink
Support different embedding types of model response
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Dec 9, 2024
1 parent 3c7f275 commit 56fc80a
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public class VectorUtil {
* @param vectorAsList {@link List} of {@link Float}'s representing the vector
* @return array of floats produced from input list
*/
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
float[] vector = new float[vectorAsList.size()];
for (int i = 0; i < vectorAsList.size(); i++) {
vector[i] = vectorAsList.get(i);
vector[i] = vectorAsList.get(i).floatValue();
}
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class MLCommonsClientAccessor {
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
if (response.size() != 1) {
Expand Down Expand Up @@ -82,7 +82,7 @@ public void inferenceSentence(
public void inferenceSentences(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
}
Expand All @@ -103,7 +103,7 @@ public void inferenceSentences(
@NonNull final List<String> targetResponseFilters,
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}
Expand All @@ -128,7 +128,7 @@ public void inferenceSentencesWithMapResult(
public void inferenceSentences(
@NonNull final String modelId,
@NonNull final Map<String, String> inputObjects,
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}
Expand Down Expand Up @@ -177,11 +177,11 @@ private void retryableInferenceSentencesWithVectorResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
final ActionListener<List<List<Number>>> listener
) {
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
Expand All @@ -202,7 +202,8 @@ private void retryableInferenceSimilarityWithVectorResult(
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
final List<List<Float>> tensors = buildVectorFromResponse(mlOutput);
final List<Float> scores = tensors.stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
Expand All @@ -224,14 +225,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<T>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
}
}
return vector;
Expand All @@ -255,8 +256,8 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return resultMaps;
}

private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

Expand All @@ -265,11 +266,11 @@ private void retryableInferenceSentencesWithSingleVectorResult(
final String modelId,
final Map<String, String> inputObjects,
final int retryTime,
final ActionListener<List<Float>> listener
final ActionListener<List<Number>> listener
) {
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest

}

private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
Expand Down Expand Up @@ -164,7 +164,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
Map<String, Object> result = new LinkedHashMap<>();
result.put(knnKey, modelTensorList);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
public class VectorUtilTests extends OpenSearchTestCase {

public void testVectorAsListToArray() {
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
}

List<Float> vectorAsList_withNoElements = Collections.emptyList();
List<Number> vectorAsList_withNoElements = Collections.emptyList();
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
assertEquals(0, vectorAsArray_withNoElements.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
private ActionListener<List<List<Float>>> resultListener;
private ActionListener<List<List<Number>>> resultListener;

@Mock
private ActionListener<List<Float>> singleSentenceResultListener;
private ActionListener<List<Number>> singleSentenceResultListener;

@Mock
private ActionListener<List<Float>> similarityResultListener;

@Mock
private MachineLearningNodeClient client;
Expand All @@ -53,7 +56,7 @@ public void setup() {
}

public void testInferenceSentence_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand All @@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand All @@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
}

public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand Down Expand Up @@ -278,7 +281,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
}

public void testInferenceMultimodal_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand Down Expand Up @@ -337,13 +340,13 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Expand All @@ -358,13 +361,13 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
Expand All @@ -382,12 +385,12 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi
TestCommonConstants.MODEL_ID,
"is it sunny",
List.of("it is sunny today", "roses are red"),
singleSentenceResultListener
similarityResultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,10 @@ public void testHashAndEquals() {
@SneakyThrows
public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K);
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(2);
ActionListener<List<Number>> listener = invocation.getArgument(2);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
Expand Down Expand Up @@ -682,10 +682,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
.queryImage(IMAGE_TEXT)
.modelId(MODEL_ID)
.k(K);
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(2);
ActionListener<List<Number>> listener = invocation.getArgument(2);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ protected float[] runInference(final String modelId, final String queryText) {
List<Object> output = (List<Object>) result.get("output");
assertEquals(1, output.size());
Map<String, Object> map = (Map<String, Object>) output.get(0);
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
return vectorAsListToArray(data);
}

Expand Down

0 comments on commit 56fc80a

Please sign in to comment.