Skip to content

Commit

Permalink
Add Support for Lucene Byte Sized Vector (#971) (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
naveentatikonda authored Jul 12, 2023
1 parent 86b812c commit 581fadd
Show file tree
Hide file tree
Showing 27 changed files with 1,598 additions and 189 deletions.
3 changes: 2 additions & 1 deletion release-notes/opensearch-knn.release-notes-2.9.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
Compatible with OpenSearch 2.9.0

### Features
Added support for Efficient Pre-filtering for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936))
* Added support for Efficient Pre-filtering for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936))
* Add Support for Lucene Byte Sized Vector ([#971](https://github.com/opensearch-project/k-NN/pull/971))
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.common;

import org.opensearch.knn.index.VectorDataType;

public class KNNConstants {
// shared across library constants
public static final String DIMENSION = "dimension";
Expand Down Expand Up @@ -50,6 +52,9 @@ public class KNNConstants {
public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count";
public static final String SEARCH_SIZE_PARAMETER = "search_size";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ public class KNNVectorDVLeafFieldData implements LeafFieldData {

private final LeafReader reader;
private final String fieldName;
private final VectorDataType vectorDataType;

public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName) {
public KNNVectorDVLeafFieldData(LeafReader reader, String fieldName, VectorDataType vectorDataType) {
this.reader = reader;
this.fieldName = fieldName;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -38,7 +40,7 @@ public long ramBytesUsed() {
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ public class KNNVectorIndexFieldData implements IndexFieldData<KNNVectorDVLeafFi

private final String fieldName;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType) {
public KNNVectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.fieldName = fieldName;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
Expand All @@ -39,7 +41,7 @@ public ValuesSourceType getValuesSourceType() {

@Override
public KNNVectorDVLeafFieldData load(LeafReaderContext context) {
return new KNNVectorDVLeafFieldData(context.reader(), fieldName);
return new KNNVectorDVLeafFieldData(context.reader(), fieldName, vectorDataType);
}

@Override
Expand Down Expand Up @@ -70,15 +72,17 @@ public static class Builder implements IndexFieldData.Builder {

private final String name;
private final ValuesSourceType valuesSourceType;
private final VectorDataType vectorDataType;

public Builder(String name, ValuesSourceType valuesSourceType) {
public Builder(String name, ValuesSourceType valuesSourceType, VectorDataType vectorDataType) {
this.name = name;
this.valuesSourceType = valuesSourceType;
this.vectorDataType = vectorDataType;
}

@Override
public IndexFieldData<?> build(IndexFieldDataCache cache, CircuitBreakerService breakerService) {
return new KNNVectorIndexFieldData(name, valuesSourceType);
return new KNNVectorIndexFieldData(name, valuesSourceType, vectorDataType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,22 @@

package org.opensearch.knn.index;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final String fieldName;
private boolean docExists;

public KNNVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName) {
this.binaryDocValues = binaryDocValues;
this.fieldName = fieldName;
}
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
Expand All @@ -47,11 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
BytesRef value = binaryDocValues.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
return vector;
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
120 changes: 120 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

/**
* Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin.
* We have two vector data_types, one is float (default) and the other one is byte.
*/
@AllArgsConstructor
public enum VectorDataType {
BYTE("byte") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
}
},
FLOAT("float") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
.map(VectorDataType::getValue)
.collect(Collectors.joining(","));
@Getter
private final String value;

/**
* Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and
* VectorSimilarityFunction.
*
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @return FieldType
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Deserializes float vector from doc values binary value.
*
* @param binaryValue Binary Value of DocValues
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
* @return the same VectorDataType if it is in the supported values
* throws Exception if an invalid value is provided.
*/
public static VectorDataType get(String vectorDataType) {
Objects.requireNonNull(
vectorDataType,
String.format(
Locale.ROOT,
"[%s] should not be null. Supported types are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
try {
return VectorDataType.valueOf(vectorDataType.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
}
}
}
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,19 @@ public VectorField(String name, float[] value, IndexableFieldType type) {
throw new RuntimeException(e);
}
}

/**
* @param name FieldType name
* @param value an array of byte vector values
* @param type FieldType to build DocValues
*/
public VectorField(String name, byte[] value, IndexableFieldType type) {
super(name, new BytesRef(), type);
try {
this.setBytesValue(value);
} catch (Exception e) {
throw new RuntimeException(e);
}

}
}
Loading

0 comments on commit 581fadd

Please sign in to comment.