Skip to content

Commit

Permalink
Fix lucene codec after lucene version bumped to 9.12
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Oct 8, 2024
1 parent 2c170fb commit 8baeaba
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
* Fix sed command in DEVELOPER_GUIDE.md to append a new line character '\n'. [#2181](https://github.com/opensearch-project/k-NN/pull/2181)
### Maintenance
* Fix lucene codec after lucene version bumped to 9.12. [#2195](https://github.com/opensearch-project/k-NN/pull/2195)
### Refactoring
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.Builder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;

/**
 * k-NN codec for the Lucene 9.12 generation. It is a {@link FilterCodec}: every format that is not
 * overridden below is served by the wrapped delegate codec, while doc values, compound files and
 * KNN vectors are routed through k-NN specific formats.
 */
public class KNN9120Codec extends FilterCodec {
    // Codec generation whose name, default delegate and format suppliers this class uses.
    private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;

    private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
    private final KNNFormatFacade knnFormatFacade;

    /**
     * No-arg constructor. Uses the default delegate codec and the default per-field KNN vectors
     * format registered for {@link KNNCodecVersion#V_9_12_0}.
     */
    public KNN9120Codec() {
        this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
    }

    /**
     * Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec
     * and a unique name to this ctor.
     *
     * @param delegate codec that will perform all operations this codec does not override
     * @param knnVectorsFormat per field format for KnnVector
     */
    @Builder
    protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
        super(VERSION.getCodecName(), delegate);
        this.perFieldKnnVectorsFormat = knnVectorsFormat;
        // The facade wraps the delegate's doc values and compound formats.
        this.knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
    }

    /**
     * @return the per-field KNN vectors format supplied at construction time
     */
    @Override
    public KnnVectorsFormat knnVectorsFormat() {
        return this.perFieldKnnVectorsFormat;
    }

    /**
     * @return the compound format exposed by the k-NN format facade
     */
    @Override
    public CompoundFormat compoundFormat() {
        return this.knnFormatFacade.compoundFormat();
    }

    /**
     * @return the doc values format exposed by the k-NN format facade
     */
    @Override
    public DocValuesFormat docValuesFormat() {
        return this.knnFormatFacade.docValuesFormat();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

import lombok.Getter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand All @@ -44,22 +46,37 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
@Getter
private final DocsWithFieldSet docsWithField;
private final InfoStream infoStream;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

static NativeEngineFieldVectorsWriter<?> create(final FieldInfo fieldInfo, final InfoStream infoStream) {
@SuppressWarnings("unchecked")
static NativeEngineFieldVectorsWriter<?> create(
final FieldInfo fieldInfo,
final FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
final InfoStream infoStream
) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
return new NativeEngineFieldVectorsWriter<float[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(
fieldInfo,
(FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
infoStream
);
case BYTE:
return new NativeEngineFieldVectorsWriter<byte[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(fieldInfo, (FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter, infoStream);
}
throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding());
}

private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) {
/**
 * Creates a writer for a single vector field. Callers go through the static {@code create} factory,
 * which selects the element type from the field's vector encoding.
 *
 * @param fieldInfo metadata of the vector field being written
 * @param flatFieldVectorsWriter flat writer that is also handed every vector added to this writer
 * @param infoStream stream used for diagnostic output
 */
private NativeEngineFieldVectorsWriter(
    final FieldInfo fieldInfo,
    final FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
    final InfoStream infoStream
) {
    // Collaborators handed in by the caller.
    this.fieldInfo = fieldInfo;
    this.flatFieldVectorsWriter = flatFieldVectorsWriter;
    this.infoStream = infoStream;
    // Per-field state, empty until addValue is called.
    this.docsWithField = new DocsWithFieldSet();
    this.vectors = new HashMap<>();
}

/**
Expand All @@ -70,7 +87,7 @@ private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStre
* @param vectorValue T
*/
@Override
public void addValue(int docID, T vectorValue) {
public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"[NativeEngineKNNVectorWriter]VectorValuesField \""
Expand All @@ -81,6 +98,8 @@ public void addValue(int docID, T vectorValue) {
// TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative
// graph build support on the JNI layer.
assert docID > lastDocID;
// ensuring that vector is provided to flatFieldWriter.
flatFieldVectorsWriter.addValue(docID, vectorValue);
vectors.put(docID, vectorValue);
docsWithField.add(docID);
lastDocID = docID;
Expand All @@ -105,6 +124,7 @@ public long ramBytesUsed() {
return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance(
Integer.class
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter
.ramBytesUsed();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla
*/
@Override
public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOException {
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream);
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(
fieldInfo,
flatVectorsWriter.addField(fieldInfo),
segmentWriteState.infoStream
);
fields.add(newField);
return flatVectorsWriter.addField(fieldInfo, newField);
return newField;
}

/**
Expand Down
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat;
import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec;
import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec;
Expand Down Expand Up @@ -110,9 +112,24 @@ public enum KNNCodecVersion {
.knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.build(),
KNN990Codec::new
),

// Codec generation for Lucene 9.12: delegates to Lucene912Codec and reuses the KNN80 doc values and
// compound formats together with the KNN990 per-field vectors format.
V_9_12_0(
    // KNN9120Codec passes this name to FilterCodec, so it has to identify that class on its own
    // rather than reuse "KNN990Codec"; Lucene looks codecs up by name when reading segments.
    // NOTE(review): confirm KNN9120Codec is listed in META-INF/services/org.apache.lucene.codecs.Codec.
    "KNN9120Codec",
    new Lucene912Codec(),
    new KNN990PerFieldKnnVectorsFormat(Optional.empty()),
    (delegate) -> new KNNFormatFacade(
        new KNN80DocValuesFormat(delegate.docValuesFormat()),
        new KNN80CompoundFormat(delegate.compoundFormat())
    ),
    (userCodec, mapperService) -> KNN9120Codec.builder()
        .delegate(userCodec)
        .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
        .build(),
    KNN9120Codec::new
);

private static final KNNCodecVersion CURRENT = V_9_9_0;
private static final KNNCodecVersion CURRENT = V_9_12_0;

private final String codecName;
private final Codec defaultCodecDelegate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public KNNScalarQuantizedVectorsFormatParams(Map<String, Object> params, int def
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.initConfidenceInterval(sqEncoderParams);
this.initBits(sqEncoderParams);
this.initCompressFlag();
// compression flag should be set after bits has been initialised as compressionFlag depends on bits.
this.setCompressionFlag();
}

@Override
Expand Down Expand Up @@ -76,7 +77,14 @@ private void initBits(final Map<String, Object> params) {
this.bits = LUCENE_SQ_DEFAULT_BITS;
}

private void initCompressFlag() {
this.compressFlag = true;
/**
 * Derives {@code compressFlag} from {@code bits}. Must be called only after {@code bits} has been
 * initialised, because the flag is computed from it.
 *
 * @throws IllegalArgumentException if {@code bits} is zero or negative
 */
private void setCompressionFlag() {
    final boolean bitsInitialised = this.bits > 0;
    if (!bitsInitialised) {
        throw new IllegalArgumentException(
            "Either bits are set to less than 0 or they have not been initialized." + " Bit value: " + this.bits
        );
    }
    // Compression is enabled only for bit counts up to 4; the rule is taken from Lucene, see
    // https://github.com/apache/lucene/blob/branch_9_12/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java#L113-L116
    this.compressFlag = !(this.bits > 4);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.InfoStream;
Expand All @@ -21,82 +23,109 @@
public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase {

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCreate_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final FlatFieldVectorsWriter<float[]> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
floatWriter.addValue(1, new float[] { 1.0f, 2.0f });
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
final float[] floatVector = new float[] { 1.0f, 2.0f };
floatWriter.addValue(1, floatVector);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, floatVector);

Mockito.verify(fieldInfo).getVectorEncoding();
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, floatVector);

final byte[] byteVector = new byte[] { 1, 2 };
final FlatFieldVectorsWriter<byte[]> mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter.create(
fieldInfo,
mockedFlatFieldByteVectorsWriter,
InfoStream.getDefault()
);
Assert.assertNotNull(byteWriter);
Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding();
byteWriter.addValue(1, new byte[] { 1, 2 });
byteWriter.addValue(1, byteVector);
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector);
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testAddValue_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
final FlatFieldVectorsWriter<float[]> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
final float[] vec1 = new float[] { 1.0f, 2.0f };
final float[] vec2 = new float[] { 2.0f, 2.0f };
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, vec1);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(2, vec2);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
floatWriter.addValue(1, vec1);
floatWriter.addValue(2, vec2);
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, vec1);
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(2, vec2);

Assert.assertEquals(vec1, floatWriter.getVectors().get(1));
Assert.assertEquals(vec2, floatWriter.getVectors().get(2));
Mockito.verify(fieldInfo).getVectorEncoding();

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
final FlatFieldVectorsWriter<byte[]> mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
final byte[] bvec1 = new byte[] { 1, 2 };
final byte[] bvec2 = new byte[] { 2, 2 };
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1);
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldByteVectorsWriter, InfoStream.getDefault());
byteWriter.addValue(1, bvec1);
byteWriter.addValue(2, bvec2);

Assert.assertEquals(bvec1, byteWriter.getVectors().get(1));
Assert.assertEquals(bvec2, byteWriter.getVectors().get(2));
Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding();
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1);
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2);
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCopyValue_whenValidInput_thenException() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3]));

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3]));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testRamByteUsed_whenValidInput_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out the expected bytes. This can be OS dependent too.
Assert.assertTrue(floatWriter.ramBytesUsed() > 0);

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out the expected bytes. This can be OS dependent too.
Assert.assertTrue(byteWriter.ramBytesUsed() > 0);
Mockito.verify(mockedFlatFieldVectorsWriter, Mockito.times(2)).ramBytesUsed();

}
}
Loading

0 comments on commit 8baeaba

Please sign in to comment.