From f71e5ee1f623a53d317b3d8c95458f66854ee540 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Tue, 11 Jul 2023 21:18:55 -0500 Subject: [PATCH] Add Support for Lucene Byte Sized Vector (#971) * Add Indexing Support for Lucene Byte Sized Vector (#937) * Add Indexing Support for Lucene Byte Sized Vector Signed-off-by: Naveen Tatikonda * Add tests for Indexing Signed-off-by: Naveen Tatikonda * Add CHANGELOG Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda * Add Querying Support to Lucene Byte Sized Vector (#956) * Add Querying Support to Lucene Byte Sized Vector Signed-off-by: Naveen Tatikonda * Add CHANGELOG Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda * Add DocValues Support for Lucene Byte Sized Vector (#953) Signed-off-by: Naveen Tatikonda * Update Release Notes Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda (cherry picked from commit bf04854c483fdfd38663bdc490e1730c994bda6d) Signed-off-by: Naveen Tatikonda --- .../opensearch-knn.release-notes-2.9.0.0.md | 3 +- .../opensearch/knn/common/KNNConstants.java | 5 + .../knn/index/KNNVectorDVLeafFieldData.java | 6 +- .../knn/index/KNNVectorIndexFieldData.java | 12 +- .../knn/index/KNNVectorScriptDocValues.java | 22 +- .../opensearch/knn/index/VectorDataType.java | 120 ++++ .../org/opensearch/knn/index/VectorField.java | 15 + .../index/mapper/KNNVectorFieldMapper.java | 171 ++++-- .../mapper/KNNVectorFieldMapperUtil.java | 164 ++++++ .../knn/index/mapper/LuceneFieldMapper.java | 65 ++- .../knn/index/query/KNNQueryBuilder.java | 17 +- .../knn/index/query/KNNQueryFactory.java | 68 ++- .../knn/plugin/script/KNNScoringSpace.java | 30 +- .../plugin/script/KNNScoringSpaceUtil.java | 16 +- .../knn/plugin/script/KNNScoringUtil.java | 27 +- .../index/KNNVectorDVLeafFieldDataTests.java | 18 +- .../index/KNNVectorIndexFieldDataTests.java | 2 +- .../index/KNNVectorScriptDocValuesTests.java | 3 +- .../opensearch/knn/index/LuceneEngineIT.java | 111 ++-- .../knn/index/VectorDataTypeIT.java | 545 ++++++++++++++++++ .../knn/index/VectorDataTypeTests.java | 109 ++++ .../knn/index/codec/KNNCodecTestCase.java | 19 +- .../mapper/KNNVectorFieldMapperTests.java | 196 ++++++- .../knn/index/query/KNNQueryBuilderTests.java | 4 + .../knn/index/query/KNNQueryFactoryTests.java | 20 +- .../script/KNNScoringSpaceUtilTests.java | 10 +- .../plugin/script/KNNScoringUtilTests.java | 9 +- 27 files changed, 1598 insertions(+), 189 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/VectorDataType.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java create mode 100644 src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java create mode 100644 src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java diff --git a/release-notes/opensearch-knn.release-notes-2.9.0.0.md b/release-notes/opensearch-knn.release-notes-2.9.0.0.md index 10e9cbdda..0ea90d037 100644 --- a/release-notes/opensearch-knn.release-notes-2.9.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.9.0.0.md @@ -3,4 +3,5 @@ Compatible with OpenSearch 2.9.0 ### Features -Added support for Efficient Pre-filtering for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936)) +* Added support for Efficient Pre-filtering for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936)) +* Add Support for Lucene Byte Sized Vector ([#971](https://github.com/opensearch-project/k-NN/pull/971)) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 87d7a1c21..6d387eec4 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -5,6 +5,8 @@ package org.opensearch.knn.common; +import org.opensearch.knn.index.VectorDataType; + public class KNNConstants { // shared across library constants public static final String DIMENSION = "dimension"; @@ -50,6 +52,9 @@ public class KNNConstants { public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count"; public static final String SEARCH_SIZE_PARAMETER = "search_size"; + public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; + // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 5f522e3de..f4caa4f20 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -18,10 +18,12 @@ public class KNNVectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; private final String fieldName; + private final VectorDataType vectorDataType; - public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName) { + public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName, VectorDataType vectorDataType) { this.reader = reader; this.fieldName = fieldName; + this.vectorDataType = vectorDataType; } @Override @@ -38,7 +40,7 @@ public long ramBytesUsed() { public ScriptDocValues getScriptValues() { try { BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName); + return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java index 367cfae53..deef8bae1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java @@ -21,10 +21,12 @@ public class KNNVectorIndexFieldData implements IndexFieldData build(IndexFieldDataCache cache, CircuitBreakerService breakerService) { - return new KNNVectorIndexFieldData(name, valuesSourceType); + return new KNNVectorIndexFieldData(name, valuesSourceType, vectorDataType); } } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 0c8240dd4..9f7d52205 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -5,26 +5,22 @@ package org.opensearch.knn.index; +import lombok.Getter; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import java.io.ByteArrayInputStream; import java.io.IOException; +@RequiredArgsConstructor public final class KNNVectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues binaryDocValues; private final String fieldName; - private boolean docExists; - - public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName) { - this.binaryDocValues = binaryDocValues; - this.fieldName = fieldName; - } + @Getter + private final VectorDataType vectorDataType; + private boolean docExists = false; @Override public void setNextDocId(int docId) throws IOException { @@ -47,11 +43,7 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - BytesRef value = binaryDocValues.binaryValue(); - ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - return vector; + return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java new file mode 100644 index 000000000..23b374e9d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; +import java.util.Locale; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + +/** + * Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin. + * We have two vector data_types, one is float (default) and the other one is byte. + */ +@AllArgsConstructor +public enum VectorDataType { + BYTE("byte") { + + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); + } + + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + float[] vector = new float[binaryValue.length]; + int i = 0; + int j = binaryValue.offset; + + while (i < binaryValue.length) { + vector[i++] = binaryValue.bytes[j++]; + } + return vector; + } + }, + FLOAT("float") { + + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); + } + + @Override + public float[] getVectorFromDocValues(BytesRef binaryValue) { + ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + return vectorSerializer.byteToFloatArray(byteStream); + } + + }; + + public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) + .map(VectorDataType::getValue) + .collect(Collectors.joining(",")); + @Getter + private final String value; + + /** + * Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and + * VectorSimilarityFunction. + * + * @param dimension Dimension of the vector + * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @return FieldType + */ + public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + + /** + * Deserializes float vector from doc values binary value. + * + * @param binaryValue Binary Value of DocValues + * @return float vector deserialized from binary value + */ + public abstract float[] getVectorFromDocValues(BytesRef binaryValue); + + /** + * Validates if given VectorDataType is in the list of supported data types. + * @param vectorDataType VectorDataType + * @return the same VectorDataType if it is in the supported values + * throws Exception if an invalid value is provided. + */ + public static VectorDataType get(String vectorDataType) { + Objects.requireNonNull( + vectorDataType, + String.format( + Locale.ROOT, + "[%s] should not be null. Supported types are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ); + try { + return VectorDataType.valueOf(vectorDataType.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index c88630f6c..f28ef6238 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -23,4 +23,19 @@ public VectorField(String name, float[] value, IndexableFieldType type) { throw new RuntimeException(e); } } + + /** + * @param name FieldType name + * @param value an array of byte vector values + * @param type FieldType to build DocValues + */ + public VectorField(String name, byte[] value, IndexableFieldType type) { + super(name, new BytesRef(), type); + try { + this.setBytesValue(value); + } catch (Exception e) { + throw new RuntimeException(e); + } + + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index ab45c384f..346d4c238 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,7 +11,6 @@ import org.opensearch.knn.common.KNNConstants; import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.search.DocValuesFieldExistsQuery; @@ -35,6 +34,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -45,11 +45,21 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension; /** * Field Mapper for KNN vector type. @@ -96,6 +106,18 @@ public static class Builder extends ParametrizedFieldMapper.Builder { return value; }, m -> toType(m).dimension); + /** + * data_type which defines the datatype of the vector values. This is an optional parameter and + * this is right now only relevant for lucene engine. The default value is float. + */ + private final Parameter vectorDataType = new Parameter<>( + VECTOR_DATA_TYPE_FIELD, + false, + () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, + (n, c, o) -> VectorDataType.get((String) o), + m -> toType(m).vectorDataType + ); + /** * modelId provides a way for a user to generate the underlying library indices from an already serialized * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for @@ -168,7 +190,7 @@ public Builder(String name, String spaceType, String m, String efConstruction) { @Override protected List> getParameters() { - return Arrays.asList(stored, hasDocValues, dimension, meta, knnMethodContext, modelId); + return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId); } protected Explicit ignoreMalformed(BuilderContext context) { @@ -203,7 +225,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { buildFullName(context), metaValue, dimension.getValue(), - knnMethodContext + knnMethodContext, + vectorDataType.getValue() ); if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { log.debug(String.format("Use [LuceneFieldMapper] mapper for field [%s]", name)); @@ -216,10 +239,17 @@ public KNNVectorFieldMapper build(BuilderContext context) { .ignoreMalformed(ignoreMalformed) .stored(stored.get()) .hasDocValues(hasDocValues.get()) + .vectorDataType(vectorDataType.getValue()) .knnMethodContext(knnMethodContext) .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); } + + // Validates and throws exception if data_type field is set in the index mapping + // using any VectorDataType (other than float, which is default) because other + // VectorDataTypes are only supported for lucene engine. + validateVectorDataTypeWithEngine(vectorDataType); + return new MethodFieldMapper( name, mappedFieldType, @@ -265,9 +295,14 @@ public KNNVectorFieldMapper build(BuilderContext context) { this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings()); } + // Validates and throws exception if index.knn is set to true in the index settings + // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper + // and it only supports float VectorDataType + validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); + return new LegacyFieldMapper( name, - new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue()), + new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()), multiFieldsBuilder, copyToBuilder, ignoreMalformed, @@ -336,20 +371,43 @@ public static class KNNVectorFieldType extends MappedFieldType { int dimension; String modelId; KNNMethodContext knnMethodContext; + VectorDataType vectorDataType; - public KNNVectorFieldType(String name, Map meta, int dimension) { - this(name, meta, dimension, null, null); + public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType) { + this(name, meta, dimension, null, null, vectorDataType); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null); + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + KNNMethodContext knnMethodContext, + VectorDataType vectorDataType + ) { + this(name, meta, dimension, knnMethodContext, null, vectorDataType); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + KNNMethodContext knnMethodContext, + String modelId, + VectorDataType vectorDataType + ) { super(name, false, false, true, TextSearchInfo.NONE, meta); this.dimension = dimension; this.modelId = modelId; this.knnMethodContext = knnMethodContext; + this.vectorDataType = vectorDataType; } @Override @@ -378,7 +436,7 @@ public Query termQuery(Object value, QueryShardContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { failIfNoDocValues(); - return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES); + return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); } } @@ -386,6 +444,7 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S protected boolean stored; protected boolean hasDocValues; protected Integer dimension; + protected VectorDataType vectorDataType; protected ModelDao modelDao; // These members map to parameters in the builder. They need to be declared in the abstract class due to the @@ -408,6 +467,7 @@ public KNNVectorFieldMapper( this.stored = stored; this.hasDocValues = hasDocValues; this.dimension = mappedFieldType.getDimension(); + this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); } @@ -430,18 +490,34 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - Optional arrayOptional = getFloatsFromContext(context, dimension); + if (VectorDataType.BYTE == vectorDataType) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension); - if (!arrayOptional.isPresent()) { - return; - } - final float[] array = arrayOptional.get(); - VectorField point = new VectorField(name(), array, fieldType); + if (!bytesArrayOptional.isPresent()) { + return; + } + final byte[] array = bytesArrayOptional.get(); + VectorField point = new VectorField(name(), array, fieldType); - context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + } else if (VectorDataType.FLOAT == vectorDataType) { + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); + + if (!floatsArrayOptional.isPresent()) { + return; + } + final float[] array = floatsArrayOptional.get(); + VectorField point = new VectorField(name(), array, fieldType); + + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) + ); } + context.path().remove(); } @@ -459,50 +535,65 @@ void validateIfKNNPluginEnabled() { } } - Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + // Returns an optional array of byte values where each value in the vector is parsed as a float and validated + // if it is a finite number without any decimals and within the byte range of [-128 to 127]. + Optional getBytesFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - ArrayList vector = new ArrayList<>(); + ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); float value; + if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - - if (Float.isNaN(value)) { - throw new IllegalArgumentException("KNN vector values cannot be NaN"); - } - - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); - } - - vector.add(value); + validateByteVectorValue(value); + vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); + validateByteVectorValue(value); + vector.add((byte) value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return Optional.empty(); + } + validateVectorDimension(dimension, vector.size()); + byte[] array = new byte[vector.size()]; + int i = 0; + for (Byte f : vector) { + array[i++] = f; + } + return Optional.of(array); + } - if (Float.isNaN(value)) { - throw new IllegalArgumentException("KNN vector values cannot be NaN"); - } + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + context.path().add(simpleName()); - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); + ArrayList vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + validateFloatVectorValue(value); + vector.add(value); + token = context.parser().nextToken(); } - + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = context.parser().floatValue(); + validateFloatVectorValue(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { context.path().remove(); return Optional.empty(); } - - if (dimension != vector.size()) { - String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vector.size()); - throw new IllegalArgumentException(errorMessage); - } + validateVectorDimension(dimension, vector.size()); float[] array = new float[vector.size()]; int i = 0; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java new file mode 100644 index 000000000..bf331eeb3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -0,0 +1,164 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.DocValuesType; +import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + +public class KNNVectorFieldMapperUtil { + /** + * Validate the float vector value and throw exception if it is not a number or not in the finite range. + * + * @param value float vector value + */ + public static void validateFloatVectorValue(float value) { + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + } + + /** + * Validate the float vector value in the byte range if it is a finite number, + * with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException. + * + * @param value float value in byte range + */ + public static void validateByteVectorValue(float value) { + validateFloatVectorValue(value); + if (value % 1 != 0) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + + ); + } + if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ); + } + } + + /** + * Validate if the given vector size matches with the dimension provided in mapping. + * + * @param dimension dimension of vector + * @param vectorSize size of the vector + */ + public static void validateVectorDimension(int dimension, int vectorSize) { + if (dimension != vectorSize) { + String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); + throw new IllegalArgumentException(errorMessage); + } + + } + + /** + * Validates and throws exception if data_type field is set in the index mapping + * using any VectorDataType (other than float, which is default) because other + * VectorDataTypes are only supported for lucene engine. + * + * @param vectorDataType VectorDataType Parameter + */ + public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter vectorDataType) { + if (VectorDataType.FLOAT == vectorDataType.getValue()) { + return; + } + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue().getValue(), + LUCENE_NAME + ) + ); + } + + /** + * Validates and throws exception if index.knn is set to true in the index settings + * using any VectorDataType (other than float, which is default) because we are using NMSLIB engine + * for LegacyFieldMapper, and it only supports float VectorDataType + * + * @param knnIndexSetting index.knn setting in the index settings + * @param vectorDataType VectorDataType Parameter + */ + public static void validateVectorDataTypeWithKnnIndexSetting( + boolean knnIndexSetting, + ParametrizedFieldMapper.Parameter vectorDataType + ) { + + if (VectorDataType.FLOAT == vectorDataType.getValue()) { + return; + } + if (knnIndexSetting) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue().getValue(), + LUCENE_NAME + ) + ); + } + } + + /** + * @param knnEngine KNNEngine + * @return DocValues FieldType of type Binary + */ + public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { + FieldType field = new FieldType(); + field.putAttribute(KNN_ENGINE, knnEngine.getName()); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } + + public static void addStoredFieldForVectorField( + ParseContext context, + FieldType fieldType, + String mapperName, + String vectorFieldAsString + ) { + if (fieldType.stored()) { + context.doc().add(new StoredField(mapperName, vectorFieldAsString)); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 5dcb09318..94e42ee7c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -9,21 +9,24 @@ import lombok.Getter; import lombok.NonNull; import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; -import org.apache.lucene.document.StoredField; -import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import java.util.Optional; import static org.apache.lucene.index.VectorValues.MAX_DIMENSIONS; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; /** * Field mapper for case when Lucene has been set as an engine. @@ -34,6 +37,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { /** FieldType used for initializing VectorField, which is used for creating binary doc values. **/ private final FieldType vectorFieldType; + private final VectorDataType vectorDataType; LuceneFieldMapper(final CreateLuceneFieldMapperInput input) { super( @@ -46,6 +50,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.isHasDocValues() ); + vectorDataType = input.getVectorDataType(); this.knnMethod = input.getKnnMethodContext(); final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType().getVectorSimilarityFunction(); @@ -53,6 +58,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { if (dimension > LUCENE_MAX_DIMENSION) { throw new IllegalArgumentException( String.format( + Locale.ROOT, "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", LUCENE_MAX_DIMENSION, dimension, @@ -61,7 +67,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { ); } - this.fieldType = KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); + this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); if (this.hasDocValues) { this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); @@ -70,36 +76,46 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { } } - private static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - FieldType field = new FieldType(); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } - @Override protected void parseCreateField(ParseContext context, int dimension) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - Optional arrayOptional = getFloatsFromContext(context, dimension); + if (VectorDataType.BYTE == vectorDataType) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension); + if (bytesArrayOptional.isEmpty()) { + return; + } + final byte[] array = bytesArrayOptional.get(); + KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); - if (arrayOptional.isEmpty()) { - return; - } - final float[] array = arrayOptional.get(); + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); - KnnVectorField point = new KnnVectorField(name(), array, fieldType); + if (hasDocValues && vectorFieldType != null) { + context.doc().add(new VectorField(name(), array, vectorFieldType)); + } + } else if (VectorDataType.FLOAT == vectorDataType) { + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); - context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); - } + if (floatsArrayOptional.isEmpty()) { + return; + } + final float[] array = floatsArrayOptional.get(); + + KnnVectorField point = new KnnVectorField(name(), array, fieldType); - if (hasDocValues && vectorFieldType != null) { - context.doc().add(new VectorField(name(), array, vectorFieldType)); + context.doc().add(point); + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + + if (hasDocValues && vectorFieldType != null) { + context.doc().add(new VectorField(name(), array, vectorFieldType)); + } + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) + ); } context.path().remove(); @@ -126,6 +142,7 @@ static class CreateLuceneFieldMapperInput { Explicit ignoreMalformed; boolean stored; boolean hasDocValues; + VectorDataType vectorDataType; @NonNull KNNMethodContext knnMethodContext; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index cb02aadd1..2b1950017 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,6 +12,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -32,6 +33,8 @@ import java.util.List; import java.util.Objects; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; + /** * Helper class to build the KNN query */ @@ -266,6 +269,7 @@ protected Query doToQuery(QueryShardContext context) { int fieldDimension = knnVectorFieldType.getDimension(); KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); KNNEngine knnEngine = KNNEngine.DEFAULT; + VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); if (fieldDimension == -1) { // If dimension is not set, the field uses a model and the information needs to be retrieved from there @@ -283,6 +287,15 @@ protected Query doToQuery(QueryShardContext context) { ); } + byte[] byteVector = new byte[0]; + if (VectorDataType.BYTE == vectorDataType) { + byteVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + validateByteVectorValue(vector[i]); + byteVector[i] = (byte) vector[i]; + } + } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { @@ -294,7 +307,9 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(this.vector) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vectorDataType(vectorDataType) .k(this.k) .filter(this.filter) .context(context) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 20c456c4a..65c15499d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -11,15 +11,21 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; + /** * Creates the Lucene k-NN queries */ @@ -36,12 +42,20 @@ public class KNNQueryFactory { * @param k the number of nearest neighbors to return * @return Lucene Query */ - public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) { + public static Query create( + KNNEngine knnEngine, + String indexName, + String fieldName, + float[] vector, + int k, + VectorDataType vectorDataType + ) { final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(indexName) .fieldName(fieldName) .vector(vector) + .vectorDataType(vectorDataType) .k(k) .build(); return create(createQueryRequest); @@ -59,6 +73,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { final String fieldName = createQueryRequest.getFieldName(); final int k = createQueryRequest.getK(); final float[] vector = createQueryRequest.getVector(); + final byte[] byteVector = createQueryRequest.getByteVector(); + final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { @@ -77,14 +93,54 @@ public static Query create(CreateQueryRequest createQueryRequest) { return new KNNQuery(fieldName, vector, k, indexName); } + if (VectorDataType.BYTE == vectorDataType) { + return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery); + } else if (VectorDataType.FLOAT == vectorDataType) { + return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery); + } else { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ); + } + } + + private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) { + if (filterQuery != null) { + log.debug( + String.format( + Locale.ROOT, + "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", + indexName, + fieldName, + k + ) + ); + return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery); + } + log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + return new KnnByteVectorQuery(fieldName, byteVector, k); + } + + private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) { if (filterQuery != null) { log.debug( - String.format("Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k) + String.format( + Locale.ROOT, + "Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", + indexName, + fieldName, + k + ) ); - return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KnnFloatVectorQuery(fieldName, vector, k); + return new KnnFloatVectorQuery(fieldName, floatVector, k); } private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { @@ -126,6 +182,10 @@ static class CreateQueryRequest { @Getter private float[] vector; @Getter + private byte[] byteVector; + @Getter + private VectorDataType vectorDataType; + @Getter private int k; // can be null in cases filter not passed with the knn query private QueryBuilder filter; diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index ca7526dcb..16bf6e204 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -54,7 +54,11 @@ public L2(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l2 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); } @@ -81,7 +85,11 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } @@ -159,7 +167,11 @@ public L1(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); } @@ -185,7 +197,11 @@ public LInf(Object query, MappedFieldType fieldType) { throw new IllegalArgumentException("Incompatible field_type for l-inf space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); } @@ -213,7 +229,11 @@ public InnerProd(Object query, MappedFieldType fieldType) { ); } - this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray( + query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() + ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 6f68d16b6..3ec1a9941 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; @@ -16,6 +17,7 @@ import java.util.Base64; import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; public class KNNScoringSpaceUtil { @@ -85,8 +87,8 @@ public static BigInteger parseToBigInteger(Object object) { * @param expectedDimensions int representing the expected dimension of this array. * @return float[] of the object */ - public static float[] parseToFloatArray(Object object, int expectedDimensions) { - float[] floatArray = convertVectorToPrimitive(object); + public static float[] parseToFloatArray(Object object, int expectedDimensions, VectorDataType vectorDataType) { + float[] floatArray = convertVectorToPrimitive(object, vectorDataType); if (expectedDimensions != floatArray.length) { KNNCounter.SCRIPT_QUERY_ERRORS.increment(); throw new IllegalStateException( @@ -103,13 +105,17 @@ public static float[] parseToFloatArray(Object object, int expectedDimensions) { * @return Float array representing the vector */ @SuppressWarnings("unchecked") - public static float[] convertVectorToPrimitive(Object vector) { + public static float[] convertVectorToPrimitive(Object vector, VectorDataType vectorDataType) { float[] primitiveVector = null; if (vector != null) { - final ArrayList tmp = (ArrayList) vector; + final ArrayList tmp = (ArrayList) vector; primitiveVector = new float[tmp.size()]; for (int i = 0; i < primitiveVector.length; i++) { - primitiveVector[i] = tmp.get(i).floatValue(); + float value = tmp.get(i).floatValue(); + if (VectorDataType.BYTE == vectorDataType) { + validateByteVectorValue(value); + } + primitiveVector[i] = value; } } return primitiveVector; diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 90468c2e7..130c4d8e0 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -8,11 +8,14 @@ import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.VectorDataType; import java.math.BigInteger; import java.util.List; import java.util.Objects; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; + public class KNNScoringUtil { private static Logger logger = LogManager.getLogger(KNNScoringUtil.class); @@ -54,12 +57,16 @@ public static float l2Squared(float[] queryVector, float[] inputVector) { return squaredDistance; } - private static float[] toFloat(List inputVector) { + private static float[] toFloat(List inputVector, VectorDataType vectorDataType) { Objects.requireNonNull(inputVector); float[] value = new float[inputVector.size()]; int index = 0; for (final Number val : inputVector) { - value[index++] = val.floatValue(); + float floatValue = val.floatValue(); + if (VectorDataType.BYTE == vectorDataType) { + validateByteVectorValue(floatValue); + } + value[index++] = floatValue; } return value; } @@ -81,7 +88,7 @@ private static float[] toFloat(List inputVector) { * @return L2 score */ public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { - return l2Squared(toFloat(queryVector), docValues.getValue()); + return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -127,7 +134,11 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto * @return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - return cosinesimilOptimized(toFloat(queryVector), docValues.getValue(), queryVectorMagnitude.floatValue()); + return cosinesimilOptimized( + toFloat(queryVector, docValues.getVectorDataType()), + docValues.getValue(), + queryVectorMagnitude.floatValue() + ); } /** @@ -172,7 +183,7 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { * @return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - return cosinesimil(toFloat(queryVector), docValues.getValue()); + return cosinesimil(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -232,7 +243,7 @@ public static float l1Norm(float[] queryVector, float[] inputVector) { * @return L1 score */ public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { - return l1Norm(toFloat(queryVector), docValues.getValue()); + return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -270,7 +281,7 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { * @return L-inf score */ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { - return lInfNorm(toFloat(queryVector), docValues.getValue()); + return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -307,6 +318,6 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { * @return inner product score */ public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { - return innerProduct(toFloat(queryVector), docValues.getValue()); + return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 8bda1aefc..cbe11dd6b 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -62,30 +62,38 @@ public void tearDown() throws Exception { } public void testGetScriptValues() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_INDEX_FIELD_NAME); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( + leafReaderContext.reader(), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); assertTrue(scriptValues instanceof KNNVectorScriptDocValues); } public void testGetScriptValuesWrongFieldName() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid"); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid", VectorDataType.FLOAT); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); } public void testGetScriptValuesWrongFieldType() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( + leafReaderContext.reader(), + MOCK_NUMERIC_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } public void testRamBytesUsed() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); assertEquals(0, leafFieldData.ramBytesUsed()); } public void testGetBytesValues() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java index 8523c4146..ee57cb190 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -27,7 +27,7 @@ public class KNNVectorIndexFieldDataTests extends KNNTestCase { @Before public void setUp() throws Exception { super.setUp(); - indexFieldData = new KNNVectorIndexFieldData(MOCK_INDEX_FIELD_NAME, CoreValuesSourceType.BYTES); + indexFieldData = new KNNVectorIndexFieldData(MOCK_INDEX_FIELD_NAME, CoreValuesSourceType.BYTES, VectorDataType.FLOAT); directory = newDirectory(); createEmptyDocument(directory); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 876117940..a0df3ce64 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -37,7 +37,8 @@ public void setUp() throws Exception { LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT ); } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index b05211b25..0c1f0a451 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Floats; import org.apache.http.util.EntityUtils; +import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; import org.apache.lucene.index.VectorSimilarityFunction; import org.junit.After; @@ -34,8 +35,10 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class LuceneEngineIT extends KNNRestTestCase { @@ -110,7 +113,7 @@ public void testQuery_innerProduct_notSupported() throws Exception { public void testQuery_invalidVectorDimensionInQuery() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -127,7 +130,7 @@ public void testQuery_documentsMissingField() throws Exception { SpaceType spaceType = SpaceType.L2; - createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType); + createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -224,35 +227,35 @@ public void testAddDoc() throws IOException { Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(1, getDocCount(INDEX_NAME)); } public void testUpdateDoc() throws Exception { - createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT); Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); Float[] updatedVector = { 8.0f, 8.0f }; updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(1, getDocCount(INDEX_NAME)); } public void testDeleteDoc() throws Exception { - createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT); Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); deleteKnnDoc(INDEX_NAME, DOC_ID); - refreshAllIndices(); + refreshIndex(INDEX_NAME); assertEquals(0, getDocCount(INDEX_NAME)); } - public void testQueryWithFilter() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + public void testQueryWithFilterUsingFloatVectorDataType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); addKnnDocWithAttributes( DOC_ID, @@ -262,39 +265,28 @@ public void testQueryWithFilter() throws Exception { addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); - refreshAllIndices(); + refreshIndex(INDEX_NAME); final float[] searchVector = { 6.0f, 6.0f, 4.1f }; - int kGreaterThanFilterResult = 5; - List expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3); - final Response response = searchKNNIndex( - INDEX_NAME, - new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), - kGreaterThanFilterResult - ); - final String responseBody = EntityUtils.toString(response.getEntity()); - final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + List expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3); + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); + } - assertEquals(expectedDocIds.size(), knnResults.size()); - assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + @SneakyThrows + public void testQueryWithFilterUsingByteVectorDataType() { + createKnnIndexMappingWithLuceneEngine(3, SpaceType.L2, VectorDataType.BYTE); - int kLimitsFilterResult = 1; - List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); - final Response responseKLimitsFilterResult = searchKNNIndex( - INDEX_NAME, - new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), - kLimitsFilterResult - ); - final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); - final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.0f, 3.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.0f, 2.0f, 4.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.0f, 5.0f, 7.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); - assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); - assertTrue( - knnResultsKLimitsFilterResult.stream() - .map(KNNResult::getDocId) - .collect(Collectors.toList()) - .containsAll(expectedDocIdsKLimitsFilterResult) - ); + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 6.0f, 4.0f }; + List expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3); + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); } public void testQuery_filterWithNonLuceneEngine() throws Exception { @@ -337,7 +329,7 @@ public void testQuery_filterWithNonLuceneEngine() throws Exception { } public void testIndexReopening() throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2); + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); @@ -358,13 +350,14 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception { + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject(PROPERTIES_FIELD_NAME) .startObject(FIELD_NAME) .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) @@ -384,7 +377,7 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac private void baseQueryTest(SpaceType spaceType) throws Exception { - createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType); + createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } @@ -419,4 +412,42 @@ private List queryResults(final float[] searchVector, final int k) thro assertNotNull(knnResults); return knnResults.stream().map(KNNResult::getVector).collect(Collectors.toUnmodifiableList()); } + + @SneakyThrows + private void validateQueryResultsWithFilters( + float[] searchVector, + int kGreaterThanFilterResult, + int kLimitsFilterResult, + List expectedDocIdsKGreaterThanFilterResult, + List expectedDocIdsKLimitsFilterResult + ) { + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIdsKGreaterThanFilterResult.size(), knnResults.size()); + assertTrue( + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIdsKGreaterThanFilterResult) + ); + + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java new file mode 100644 index 000000000..14ce819f8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -0,0 +1,545 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.rest.RestStatus; +import org.opensearch.script.Script; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; + +public class VectorDataTypeIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-vec-dt"; + private static final String FIELD_NAME = "test-field-vec-dt"; + private static final String PROPERTIES_FIELD = "properties"; + private static final String DOC_ID = "doc1"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String KNN_VECTOR_TYPE = "knn_vector"; + private static final int EF_CONSTRUCTION = 128; + private static final int M = 16; + private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder(); + + @After + @SneakyThrows + public final void cleanUp() { + deleteKNNIndex(INDEX_NAME); + } + + // Validate if we are able to create an index by setting data_type field as byte and add a doc to it + @SneakyThrows + public void testAddDocWithByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 6, 6 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + // Validate by creating an index by setting data_type field as byte, add a doc to it and update it later. + @SneakyThrows + public void testUpdateDocWithByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { -36, 78 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + Byte[] updatedVector = { 89, -8 }; + updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + // Validate by creating an index by setting data_type field as byte, add a doc to it and delete it later. + @SneakyThrows + public void testDeleteDocWithByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 35, -46 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + deleteKnnDoc(INDEX_NAME, DOC_ID); + refreshAllIndices(); + + assertEquals(0, getDocCount(INDEX_NAME)); + } + + @SneakyThrows + public void testSearchWithByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), 4), 4); + + validateL2SearchResults(response); + } + + @SneakyThrows + public void testSearchWithInvalidByteVector() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + // Validate search with floats instead of byte vectors + float[] queryVector = { -10.76f, 15.89f }; + ResponseException ex = expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector, 4), 4) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + + // validate search with search vectors outside of byte range + float[] queryVector1 = { -1000.0f, 200.0f }; + ResponseException ex1 = expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector1, 4), 4) + ); + + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + + @SneakyThrows + public void testSearchWithFloatVectorDataType() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + float[] queryVector = { 1.0f, 1.0f }; + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector, 4), 4); + + validateL2SearchResults(response); + } + + // Set an invalid value for data_type field while creating the index which should throw an exception + public void testInvalidVectorDataType() { + String vectorDataType = "invalidVectorType"; + ResponseException ex = expectThrows( + ResponseException.class, + () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, vectorDataType) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ) + ); + } + + // Set null value for data_type field while creating the index which should throw an exception + public void testVectorDataTypeAsNull() { + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, null)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] on mapper [%s] of type [%s] must not have a [null] value", + VECTOR_DATA_TYPE_FIELD, + FIELD_NAME, + KNN_VECTOR_TYPE + ) + ) + ); + } + + // Create an index with byte vector data_type and add a doc with decimal values which should throw exception + @SneakyThrows + public void testInvalidVectorData() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -10.76f, 15.89f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + } + + // Create an index with byte vector data_type and add a doc with values out of byte range which should throw exception + @SneakyThrows + public void testInvalidByteVectorRange() { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -1000f, 155f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + + // Create an index with byte vector data_type using nmslib engine which should throw an exception + public void testByteVectorDataTypeWithNmslibEngine() { + ResponseException ex = expectThrows( + ResponseException.class, + () -> createKnnIndexMappingWithNmslibEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + LUCENE_NAME + ) + ) + ); + } + + @SneakyThrows + public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() { + // Create an index with byte vector data_type and index.knn as true without setting KnnMethodContext, + // which should throw an exception because the LegacyFieldMapper will use NMSLIB engine and byte data_type + // is not supported for NMSLIB engine. + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 2) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + LUCENE_NAME + ) + ) + ); + + } + + public void testDocValuesWithByteVectorDataTypeLuceneEngine() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testDocValuesWithFloatVectorDataTypeLuceneEngine() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2ScriptScoreWithByteVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2ScriptScoreWithFloatVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + Float[] queryVector = { 1.0f, 1.0f }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2PainlessScriptingWithByteVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + String source = String.format("1/(1 + l2Squared([1, 1], doc['%s']))", FIELD_NAME); + Request request = constructScriptScoreContextSearchRequest( + INDEX_NAME, + MATCH_ALL_QUERY_BUILDER, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + 4 + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception { + createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue()); + ingestL2FloatTestData(); + + String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); + Request request = constructScriptScoreContextSearchRequest( + INDEX_NAME, + MATCH_ALL_QUERY_BUILDER, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + 4 + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + + public void testKNNScriptScoreWithInvalidVectorDataType() { + // Set an invalid value for data_type field while creating the index for script scoring which should throw an exception + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndexMappingForScripting(2, "invalid_data_type")); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES + ) + ) + ); + } + + public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception { + // Create an index with byte vector data_type, add docs and run a scoring script query with decimal values + // which should throw exception + createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue()); + + Byte[] f1 = { 6, 6 }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Byte[] f2 = { 2, 2 }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + // Construct Search Request with query vector having decimal values + Float[] queryVector = { 10.67f, 19.78f }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + } + + @SneakyThrows + private void ingestL2ByteTestData() { + Byte[] b1 = { 6, 6 }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, b1); + + Byte[] b2 = { 2, 2 }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, b2); + + Byte[] b3 = { 4, 4 }; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, b3); + + Byte[] b4 = { 3, 3 }; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, b4); + } + + @SneakyThrows + private void ingestL2FloatTestData() { + Float[] f1 = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Float[] f2 = { 2.0f, 2.0f }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + Float[] f3 = { 4.0f, 4.0f }; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); + + Float[] f4 = { 3.0f, 3.0f }; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); + } + + private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { + createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.NMSLIB.getName()); + } + + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { + createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.LUCENE.getName()); + } + + private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spaceType, String vectorDataType, String engine) + throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, engine) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, M) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, mapping); + } + + private void createKnnIndexMappingForScripting(int dimension, String vectorDataType) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createKnnIndex(INDEX_NAME, Settings.EMPTY, mapping); + } + + @SneakyThrows + private Request createScriptQueryRequest(Byte[] queryVector, String spaceType, QueryBuilder qb) { + Map params = new HashMap<>(); + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType); + return constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + } + + @SneakyThrows + private Request createScriptQueryRequest(Float[] queryVector, String spaceType, QueryBuilder qb) { + Map params = new HashMap<>(); + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType); + return constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + } + + @SneakyThrows + private void validateL2SearchResults(Response response) { + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(4, results.size()); + + String[] expectedDocIDs = { "2", "4", "3", "1" }; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + } + + private float[] convertByteToFloatArray(Byte[] arr) { + float[] floatArray = new float[arr.length]; + for (int i = 0; i < arr.length; i++) { + floatArray[i] = arr[i]; + } + return floatArray; + } +} diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java new file mode 100644 index 000000000..4423c85d8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import org.apache.lucene.document.BinaryDocValuesField; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.junit.Assert; +import org.opensearch.knn.KNNTestCase; + +import java.io.IOException; + +public class VectorDataTypeTests extends KNNTestCase { + + private static final String MOCK_FLOAT_INDEX_FIELD_NAME = "test-float-index-field-name"; + private static final String MOCK_BYTE_INDEX_FIELD_NAME = "test-byte-index-field-name"; + private static final float[] SAMPLE_FLOAT_VECTOR_DATA = new float[] { 10.0f, 25.0f }; + private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 10, 25 }; + private Directory directory; + private DirectoryReader reader; + + @SneakyThrows + public void testGetDocValuesWithFloatVectorDataType() { + KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testGetDocValuesWithByteVectorDataType() { + KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues(); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + + reader.close(); + directory.close(); + } + + @SneakyThrows + private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { + directory = newDirectory(); + createKNNFloatVectorDocument(directory); + reader = DirectoryReader.open(directory); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + return new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), + VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + } + + @SneakyThrows + private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { + directory = newDirectory(); + createKNNByteVectorDocument(directory); + reader = DirectoryReader.open(directory); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + return new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), + VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, + VectorDataType.BYTE + ); + } + + private void createKNNFloatVectorDocument(Directory directory) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + knnDocument.add( + new BinaryDocValuesField( + MOCK_FLOAT_INDEX_FIELD_NAME, + new VectorField(MOCK_FLOAT_INDEX_FIELD_NAME, SAMPLE_FLOAT_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } + + private void createKNNByteVectorDocument(Directory directory) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + knnDocument.add( + new BinaryDocValuesField( + MOCK_BYTE_INDEX_FIELD_NAME, + new VectorField(MOCK_BYTE_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 6bfde31bb..6c7631216 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -69,6 +69,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.Version.CURRENT; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -338,7 +339,14 @@ public void testKnnVectorIndex( verify(perFieldKnnVectorsFormatSpy, atLeastOnce()).getKnnVectorsFormatForField(eq(FIELD_NAME_ONE)); IndexSearcher searcher = new IndexSearcher(reader); - Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_ONE, new float[] { 1.0f, 0.0f, 0.0f }, 1); + Query query = KNNQueryFactory.create( + KNNEngine.LUCENE, + "dummy", + FIELD_NAME_ONE, + new float[] { 1.0f, 0.0f, 0.0f }, + 1, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertEquals(1, searcher.count(query)); @@ -365,7 +373,14 @@ public void testKnnVectorIndex( verify(perFieldKnnVectorsFormatSpy, atLeastOnce()).getKnnVectorsFormatForField(eq(FIELD_NAME_TWO)); IndexSearcher searcher1 = new IndexSearcher(reader1); - Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_TWO, new float[] { 1.0f, 0.0f }, 1); + Query query1 = KNNQueryFactory.create( + KNNEngine.LUCENE, + "dummy", + FIELD_NAME_TWO, + new float[] { 1.0f, 0.0f }, + 1, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertEquals(1, searcher1.count(query1)); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index d4a5b5aea..1f3598781 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -6,8 +6,11 @@ package org.opensearch.knn.index.mapper; import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; import org.mockito.Mockito; import org.opensearch.common.Explicit; @@ -28,6 +31,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; @@ -42,10 +46,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Optional; +import java.util.stream.Collectors; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; @@ -60,6 +67,7 @@ import static org.opensearch.Version.CURRENT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -71,9 +79,13 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { private final static float[] TEST_VECTOR = createInitializedFloatArray(TEST_DIMENSION, TEST_VECTOR_VALUE); + private final static byte TEST_BYTE_VECTOR_VALUE = 10; + private final static byte[] TEST_BYTE_VECTOR = createInitializedByteArray(TEST_DIMENSION, TEST_BYTE_VECTOR_VALUE); + private final static BytesRef TEST_VECTOR_BYTES_REF = new BytesRef( KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(TEST_VECTOR) ); + private final static BytesRef TEST_BYTE_VECTOR_BYTES_REF = new BytesRef(TEST_BYTE_VECTOR); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; private static final String TYPE_FIELD_NAME = "type"; @@ -82,7 +94,11 @@ public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao); - assertEquals(6, builder.getParameters().size()); + + assertEquals(7, builder.getParameters().size()); + List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); + List expectedParams = Arrays.asList("store", "doc_values", DIMENSION, VECTOR_DATA_TYPE_FIELD, "meta", KNN_METHOD, MODEL_ID); + assertEquals(expectedParams, actualParams); } public void testBuilder_build_fromKnnMethodContext() { @@ -334,6 +350,56 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws ); } + // Validate TypeParser parsing invalid vector data_type which throws exception + @SneakyThrows + public void testTypeParser_parse_invalidVectorDataType() { + String fieldName = "test-field-name-vec"; + String indexName = "test-index-name-vec"; + String vectorDataType = "invalid"; + String supportedTypes = String.join( + ",", + Arrays.stream((VectorDataType.values())).map(VectorDataType::getValue).collect(Collectors.toCollection(HashSet::new)) + ); + + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + XContentBuilder xContentBuilderOverInvalidVectorType = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, LUCENE_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) + .endObject() + .endObject() + .endObject(); + + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilderOverInvalidVectorType), + buildParserContext(indexName, settings) + ) + ); + assertEquals( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + supportedTypes + ), + ex.getMessage() + ); + } + public void testTypeParser_parse_fromKnnMethodContext_invalidSpaceType() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; @@ -673,30 +739,11 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } - public void testLuceneFieldMapper_parseCreateField_docValues() throws IOException { + @SneakyThrows + public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField - KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); - - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - TEST_FIELD_NAME, - Collections.emptyMap(), - TEST_DIMENSION, - knnMethodContext - ); - LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(TEST_FIELD_NAME) - .mappedFieldType(knnVectorFieldType) - .multiFields(FieldMapper.MultiFields.empty()) - .copyTo(FieldMapper.CopyTo.empty()) - .hasDocValues(true) - .ignoreMalformed(new Explicit<>(true, true)) - .knnMethodContext(knnMethodContext); + createLuceneFieldMapperInputBuilder(VectorDataType.FLOAT); ParseContext.Document document = new ParseContext.Document(); ContentPath contentPath = new ContentPath(); @@ -731,6 +778,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio } assertEquals(TEST_VECTOR_BYTES_REF, vectorField.binaryValue()); + assertEquals(VectorEncoding.FLOAT32, vectorField.fieldType().vectorEncoding()); assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); // Test when doc values are disabled @@ -757,12 +805,112 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } - public static float[] createInitializedFloatArray(int dimension, float value) { + @SneakyThrows + public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { + // Create a lucene field mapper that creates a binary doc values field as well as KnnByteVectorField + + LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = + createLuceneFieldMapperInputBuilder(VectorDataType.BYTE); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); + doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + + // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField + List fields = document.getFields(); + assertEquals(2, fields.size()); + IndexableField field1 = fields.get(0); + IndexableField field2 = fields.get(1); + + VectorField vectorField; + KnnByteVectorField knnByteVectorField; + if (field1 instanceof VectorField) { + assertTrue(field2 instanceof KnnByteVectorField); + vectorField = (VectorField) field1; + knnByteVectorField = (KnnByteVectorField) field2; + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertTrue(field2 instanceof VectorField); + knnByteVectorField = (KnnByteVectorField) field1; + vectorField = (VectorField) field2; + } + + assertEquals(TEST_BYTE_VECTOR_BYTES_REF, vectorField.binaryValue()); + assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); + + // Test when doc values are disabled + document = new ParseContext.Document(); + contentPath = new ContentPath(); + parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + inputBuilder.hasDocValues(false); + luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); + doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + + // Document should have 1 field: one for KnnByteVectorField + fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field = fields.get(0); + assertTrue(field instanceof KnnByteVectorField); + knnByteVectorField = (KnnByteVectorField) field; + assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); + } + + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( + VectorDataType vectorDataType + ) { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) + ); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + TEST_FIELD_NAME, + Collections.emptyMap(), + TEST_DIMENSION, + knnMethodContext, + vectorDataType + ); + + return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() + .name(TEST_FIELD_NAME) + .mappedFieldType(knnVectorFieldType) + .multiFields(FieldMapper.MultiFields.empty()) + .copyTo(FieldMapper.CopyTo.empty()) + .hasDocValues(true) + .vectorDataType(vectorDataType) + .ignoreMalformed(new Explicit<>(true, true)) + .knnMethodContext(knnMethodContext); + } + + private static float[] createInitializedFloatArray(int dimension, float value) { float[] array = new float[dimension]; Arrays.fill(array, value); return array; } + private static byte[] createInitializedByteArray(int dimension, byte value) { + byte[] array = new byte[dimension]; + Arrays.fill(array, value); + return array; + } + public IndexMetadata buildIndexMetaData(String indexName, Settings settings) { return IndexMetadata.builder(indexName) .settings(settings) diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 74d99a805..e97cee611 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -30,6 +30,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -175,6 +176,7 @@ public void testDoToQuery_Normal() throws Exception { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -190,6 +192,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -241,6 +244,7 @@ public void testDoToQuery_FromModel() { // Dimension is -1. In this case, model metadata will need to provide dimension when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; when(mockKNNVectorField.getModelId()).thenReturn(modelId); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 674d1be39..4dccfd087 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -24,6 +24,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -38,7 +39,14 @@ public class KNNQueryFactoryTests extends KNNTestCase { public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); + Query query = KNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testK, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); @@ -53,7 +61,14 @@ public void testCreateLuceneDefaultQuery() { .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) .collect(Collectors.toList()); for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { - Query query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK); + Query query = KNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testK, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } } @@ -71,6 +86,7 @@ public void testCreateLuceneQueryWithFilter() { .indexName(testIndexName) .fieldName(testFieldName) .vector(testQueryVector) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .k(testK) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 92fd56e45..b5bc4b95f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; @@ -64,11 +65,14 @@ public void testParseKNNVectorQuery() { KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(fieldType.getDimension()).thenReturn(3); - assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3), 0.1f); + assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); - expectThrows(IllegalStateException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 4)); + expectThrows( + IllegalStateException.class, + () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 4, VectorDataType.FLOAT) + ); String invalidObject = "invalidObject"; - expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3)); + expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 49add790e..4a2bb7254 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -7,6 +7,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -81,7 +82,7 @@ public void testGetInvalidVectorMagnitudeSquared() { public void testConvertInvalidVectorToPrimitive() { float[] primitiveVector = null; - assertEquals(primitiveVector, KNNScoringSpaceUtil.convertVectorToPrimitive(primitiveVector)); + assertEquals(primitiveVector, KNNScoringSpaceUtil.convertVectorToPrimitive(primitiveVector, VectorDataType.FLOAT)); } public void testCosineSimilQueryVectorZeroMagnitude() { @@ -243,7 +244,11 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName); + scriptDocValues = new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(fieldName), + fieldName, + VectorDataType.FLOAT + ); } return scriptDocValues; }