From 494b16063e1d06e3018e0e0e70168e2813f86f03 Mon Sep 17 00:00:00 2001 From: panguixin Date: Thu, 31 Oct 2024 23:16:09 +0800 Subject: [PATCH] Replace Map with IntObjectHashMap for KnnVectorsReader (#13763) --- lucene/CHANGES.txt | 2 + .../lucene90/Lucene90HnswVectorsReader.java | 27 ++++---- .../lucene91/Lucene91HnswVectorsReader.java | 27 ++++---- .../lucene92/Lucene92HnswVectorsReader.java | 27 ++++---- .../lucene94/Lucene94HnswVectorsReader.java | 51 +++++++-------- .../lucene95/Lucene95HnswVectorsReader.java | 63 ++++++++----------- .../SimpleTextKnnVectorsReader.java | 13 ++-- .../lucene99/Lucene99FlatVectorsReader.java | 55 +++++++--------- .../lucene99/Lucene99HnswVectorsReader.java | 51 +++++++++------ .../Lucene99ScalarQuantizedVectorsReader.java | 49 +++++++-------- .../perfield/PerFieldKnnVectorsFormat.java | 58 ++++++++++++----- 11 files changed, 218 insertions(+), 205 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 23ee3e92a855..b512720bc9ca 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -82,6 +82,8 @@ Optimizations * GITHUB#13958: Speed up advancing within a block. (Adrien Grand) +* GITHUB#13763: Replace Map with IntObjectHashMap for KnnVectorsReader (Pan Guixin) + Bug Fixes --------------------- * GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 3ffd4f4d75a0..015fad7490ce 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -20,8 +20,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.SplittableRandom; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; @@ -33,6 +31,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; @@ -50,14 +49,16 @@ */ public final class Lucene90HnswVectorsReader extends KnnVectorsReader { - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; private final long checksumSeed; + private final FieldInfos fieldInfos; Lucene90HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); long[] checksumRef = new long[1]; + this.fieldInfos = state.fieldInfos; boolean success = false; try { vectorData = @@ -158,7 +159,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -218,13 +219,18 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorIndex); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - return getOffHeapVectorValues(fieldEntry); + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return getOffHeapVectorValues(getFieldEntry(field)); } @Override @@ -235,8 +241,7 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - + final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { return; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index a140b4fd7f39..e71fa66719f8 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -21,8 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import java.util.function.IntUnaryOperator; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; @@ -35,6 +33,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; @@ -55,13 +54,15 @@ */ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); + private final FieldInfos fieldInfos; Lucene91HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; boolean success = false; try { vectorData = @@ -154,7 +155,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce } FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -214,13 +215,18 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorIndex); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - return getOffHeapVectorValues(fieldEntry); + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return getOffHeapVectorValues(getFieldEntry(field)); } @Override @@ -231,8 +237,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - + final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { return; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 39fe109a9f13..034967efbaab 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -21,8 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; @@ -34,6 +32,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -53,13 +52,15 @@ */ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); + private final FieldInfos fieldInfos; Lucene92HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; boolean success = false; try { vectorData = @@ -152,7 +153,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce } FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -212,13 +213,18 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorIndex); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - return OffHeapFloatVectorValues.load(fieldEntry, vectorData); + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData); } @Override @@ -229,8 +235,7 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - + final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { return; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index d5beae1e6811..1ad2e3023642 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -21,8 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; @@ -35,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -54,13 +53,15 @@ */ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); + private final FieldInfos fieldInfos; Lucene94HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; boolean success = false; try { vectorData = @@ -153,7 +154,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce } FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -230,48 +231,41 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorIndex); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + if (fieldEntry.vectorEncoding != expectedEncoding) { throw new IllegalArgumentException( "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + expectedEncoding); } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); return OffHeapFloatVectorValues.load(fieldEntry, vectorData); } @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { - throw new IllegalArgumentException("field=\"" + field + "\" not found"); - } - if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + VectorEncoding.BYTE); - } + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); return OffHeapByteVectorValues.load(fieldEntry, vectorData); } @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - - if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + if (fieldEntry.size() == 0 || knnCollector.k() == 0) { return; } @@ -289,9 +283,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits @Override public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - - if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + if (fieldEntry.size() == 0 || knnCollector.k() == 0) { return; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 2e6714d6eb8e..b5859daf9f2f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -21,8 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; @@ -39,6 +37,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -61,7 +60,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider { private final FieldInfos fieldInfos; - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); @@ -161,7 +160,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce } FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -238,21 +237,27 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorIndex); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + if (fieldEntry.vectorEncoding != expectedEncoding) { throw new IllegalArgumentException( "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + expectedEncoding); } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); return OffHeapFloatVectorValues.load( fieldEntry.similarityFunction, defaultFlatVectorScorer, @@ -266,19 +271,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { - throw new IllegalArgumentException("field=\"" + field + "\" not found"); - } - if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + VectorEncoding.BYTE); - } + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); return OffHeapByteVectorValues.load( fieldEntry.similarityFunction, defaultFlatVectorScorer, @@ -293,11 +286,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - - if (fieldEntry.size() == 0 - || knnCollector.k() == 0 - || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + if (fieldEntry.size() == 0 || knnCollector.k() == 0) { return; } @@ -324,11 +314,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits @Override public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - FieldEntry fieldEntry = fields.get(field); - - if (fieldEntry.size() == 0 - || knnCollector.k() == 0 - || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); + if (fieldEntry.size() == 0 || knnCollector.k() == 0) { return; } @@ -355,12 +342,12 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits /** Get knn graph values; used for testing */ @Override public HnswGraph getGraph(String field) throws IOException { - FieldInfo info = fieldInfos.fieldInfo(field); - if (info == null) { - throw new IllegalArgumentException("No such field '" + field + "'"); + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - FieldEntry entry = fields.get(field); - if (entry != null && entry.vectorIndexLength > 0) { + if (entry.vectorIndexLength > 0) { return getGraph(entry); } else { return HnswGraph.EMPTY; diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 0a8c48363212..6c7c53a38d0e 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -26,8 +26,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; @@ -36,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; @@ -63,7 +62,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { private final SegmentReadState readState; private final IndexInput dataIn; private final BytesRefBuilder scratch = new BytesRefBuilder(); - private final Map fieldEntries = new HashMap<>(); + private final IntObjectHashMap fieldEntries = new IntObjectHashMap<>(); SimpleTextKnnVectorsReader(SegmentReadState readState) throws IOException { this.readState = readState; @@ -91,9 +90,9 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader { for (int i = 0; i < size; i++) { docIds[i] = readInt(in, EMPTY); } - assert fieldEntries.containsKey(fieldName) == false; + assert fieldEntries.containsKey(fieldNumber) == false; fieldEntries.put( - fieldName, + fieldNumber, new FieldEntry( dimension, vectorDataOffset, @@ -126,7 +125,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { throw new IllegalStateException( "KNN vectors readers should not be called on fields that don't enable KNN vectors"); } - FieldEntry fieldEntry = fieldEntries.get(field); + FieldEntry fieldEntry = fieldEntries.get(info.number); if (fieldEntry == null) { // mirror the handling in Lucene90VectorReader#getVectorValues // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs @@ -159,7 +158,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { throw new IllegalStateException( "KNN vectors readers should not be called on fields that don't enable KNN vectors"); } - FieldEntry fieldEntry = fieldEntries.get(field); + FieldEntry fieldEntry = fieldEntries.get(info.number); if (fieldEntry == null) { // mirror the handling in Lucene90VectorReader#getVectorValues // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index b334298cb8f2..9b42ddd0f267 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; @@ -38,6 +36,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -56,13 +55,15 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class); - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorData; + private final FieldInfos fieldInfos; public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) throws IOException { super(scorer); int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; boolean success = false; try { vectorData = @@ -155,15 +156,13 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); } FieldEntry fieldEntry = FieldEntry.create(meta, info); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @Override public long ramBytesUsed() { - return Lucene99FlatVectorsReader.SHALLOW_SIZE - + RamUsageEstimator.sizeOfMap( - fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); } @Override @@ -171,21 +170,27 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(vectorData); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + if (fieldEntry.vectorEncoding != expectedEncoding) { throw new IllegalArgumentException( "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + expectedEncoding); } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); return OffHeapFloatVectorValues.load( fieldEntry.similarityFunction, vectorScorer, @@ -199,19 +204,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { - throw new IllegalArgumentException("field=\"" + field + "\" not found"); - } - if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + VectorEncoding.BYTE); - } + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); return OffHeapByteVectorValues.load( fieldEntry.similarityFunction, vectorScorer, @@ -225,10 +218,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return null; - } + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapFloatVectorValues.load( @@ -245,10 +235,7 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th @Override public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - return null; - } + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); return vectorScorer.getRandomVectorScorer( fieldEntry.similarityFunction, OffHeapByteVectorValues.load( diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index f27a826e9c35..2a3088527f5f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -21,9 +21,7 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; @@ -37,6 +35,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -70,7 +69,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader private final FlatVectorsReader flatVectorsReader; private final FieldInfos fieldInfos; - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput vectorIndex; public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader) @@ -162,7 +161,7 @@ private void readFields(ChecksumIndexInput meta) throws IOException { } FieldEntry fieldEntry = readField(meta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -225,8 +224,7 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio @Override public long ramBytesUsed() { return Lucene99HnswVectorsReader.SHALLOW_SIZE - + RamUsageEstimator.sizeOfMap( - fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)) + + fields.ramBytesUsed() + flatVectorsReader.ramBytesUsed(); } @@ -246,25 +244,43 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { return flatVectorsReader.getByteVectorValues(field); } + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + expectedEncoding); + } + return fieldEntry; + } + @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); search( - fields.get(field), + fieldEntry, knnCollector, acceptDocs, - VectorEncoding.FLOAT32, () -> flatVectorsReader.getRandomVectorScorer(field, target)); } @Override public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); search( - fields.get(field), + fieldEntry, knnCollector, acceptDocs, - VectorEncoding.BYTE, () -> flatVectorsReader.getRandomVectorScorer(field, target)); } @@ -272,13 +288,10 @@ private void search( FieldEntry fieldEntry, KnnCollector knnCollector, Bits acceptDocs, - VectorEncoding vectorEncoding, IOSupplier scorerSupplier) throws IOException { - if (fieldEntry.size() == 0 - || knnCollector.k() == 0 - || fieldEntry.vectorEncoding != vectorEncoding) { + if (fieldEntry.size() == 0 || knnCollector.k() == 0) { return; } final RandomVectorScorer scorer = scorerSupplier.get(); @@ -304,12 +317,12 @@ private void search( @Override public HnswGraph getGraph(String field) throws IOException { - FieldInfo info = fieldInfos.fieldInfo(field); - if (info == null) { - throw new IllegalArgumentException("No such field '" + field + "'"); + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); } - FieldEntry entry = fields.get(field); - if (entry != null && entry.vectorIndexLength > 0) { + if (entry.vectorIndexLength > 0) { return getGraph(entry); } else { return HnswGraph.EMPTY; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 32eea942e2a0..712e9b91f9d2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -21,8 +21,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; @@ -36,6 +34,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; @@ -59,15 +58,17 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class); - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); private final IndexInput quantizedVectorData; private final FlatVectorsReader rawVectorsReader; + private final FieldInfos fieldInfos; public Lucene99ScalarQuantizedVectorsReader( SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer) throws IOException { super(scorer); this.rawVectorsReader = rawVectorsReader; + this.fieldInfos = state.fieldInfos; int versionMeta = -1; String metaFileName = IndexFileNames.segmentFileName( @@ -118,7 +119,7 @@ private void readFields(ChecksumIndexInput meta, int versionMeta, FieldInfos inf } FieldEntry fieldEntry = readField(meta, versionMeta, info); validateFieldEntry(info, fieldEntry); - fields.put(info.name, fieldEntry); + fields.put(info.number, fieldEntry); } } @@ -163,10 +164,10 @@ public void checkIntegrity() throws IOException { CodecUtil.checksumEntireFile(quantizedVectorData); } - @Override - public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null) { + private FieldEntry getFieldEntry(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry fieldEntry; + if (info == null || (fieldEntry = fields.get(info.number)) == null) { throw new IllegalArgumentException("field=\"" + field + "\" not found"); } if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { @@ -178,6 +179,12 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { + " expected: " + VectorEncoding.FLOAT32); } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field); final FloatVectorValues rawVectorValues = rawVectorsReader.getFloatVectorValues(field); OffHeapQuantizedByteVectorValues quantizedByteVectorValues = OffHeapQuantizedByteVectorValues.load( @@ -241,10 +248,7 @@ private static IndexInput openDataInput( @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return null; - } + final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.scalarQuantizer == null) { return rawVectorsReader.getRandomVectorScorer(field, target); } @@ -275,12 +279,7 @@ public void close() throws IOException { @Override public long ramBytesUsed() { - long size = SHALLOW_SIZE; - size += - RamUsageEstimator.sizeOfMap( - fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); - size += rawVectorsReader.ramBytesUsed(); - return size; + return SHALLOW_SIZE + fields.ramBytesUsed() + rawVectorsReader.ramBytesUsed(); } private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info) @@ -301,11 +300,8 @@ private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info) } @Override - public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException { - FieldEntry fieldEntry = fields.get(fieldName); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return null; - } + public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field); return OffHeapQuantizedByteVectorValues.load( fieldEntry.ordToDoc, fieldEntry.dimension, @@ -320,11 +316,8 @@ public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) thro } @Override - public ScalarQuantizer getQuantizationState(String fieldName) { - FieldEntry fieldEntry = fields.get(fieldName); - if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return null; - } + public ScalarQuantizer getQuantizationState(String field) { + final FieldEntry fieldEntry = getFieldEntry(field); return fieldEntry.scalarQuantizer; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 5dc4db8db6a8..63bad6d48dad 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -19,7 +19,9 @@ import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.ServiceLoader; import org.apache.lucene.codecs.KnnFieldVectorsWriter; @@ -28,11 +30,14 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.internal.hppc.ObjectCursor; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -186,7 +191,8 @@ public long ramBytesUsed() { /** VectorReader that can wrap multiple delegate readers, selected by field. */ public static class FieldsReader extends KnnVectorsReader { - private final Map fields = new HashMap<>(); + private final IntObjectHashMap fields = new IntObjectHashMap<>(); + private final FieldInfos fieldInfos; /** * Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat @@ -196,7 +202,7 @@ public static class FieldsReader extends KnnVectorsReader { * @throws IOException if one of the delegate readers throws */ public FieldsReader(final SegmentReadState readState) throws IOException { - + this.fieldInfos = readState.fieldInfos; // Init each unique format: boolean success = false; Map formats = new HashMap<>(); @@ -221,7 +227,7 @@ public FieldsReader(final SegmentReadState readState) throws IOException { segmentSuffix, format.fieldsReader(new SegmentReadState(readState, segmentSuffix))); } - fields.put(fieldName, formats.get(segmentSuffix)); + fields.put(fi.number, formats.get(segmentSuffix)); } } } @@ -239,51 +245,69 @@ public FieldsReader(final SegmentReadState readState) throws IOException { * @param field the name of a numeric vector field */ public KnnVectorsReader getFieldReader(String field) { - return fields.get(field); + final FieldInfo info = fieldInfos.fieldInfo(field); + if (info == null) { + return null; + } + return fields.get(info.number); } @Override public void checkIntegrity() throws IOException { - for (KnnVectorsReader reader : fields.values()) { - reader.checkIntegrity(); + for (ObjectCursor cursor : fields.values()) { + cursor.value.checkIntegrity(); } } @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - KnnVectorsReader knnVectorsReader = fields.get(field); - if (knnVectorsReader == null) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final KnnVectorsReader reader; + if (info == null || (reader = fields.get(info.number)) == null) { return null; - } else { - return knnVectorsReader.getFloatVectorValues(field); } + return reader.getFloatVectorValues(field); } @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - KnnVectorsReader knnVectorsReader = fields.get(field); - if (knnVectorsReader == null) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final KnnVectorsReader reader; + if (info == null || (reader = fields.get(info.number)) == null) { return null; - } else { - return knnVectorsReader.getByteVectorValues(field); } + return reader.getByteVectorValues(field); } @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs); + final FieldInfo info = fieldInfos.fieldInfo(field); + final KnnVectorsReader reader; + if (info == null || (reader = fields.get(info.number)) == null) { + return; + } + reader.search(field, target, knnCollector, acceptDocs); } @Override public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs); + final FieldInfo info = fieldInfos.fieldInfo(field); + final KnnVectorsReader reader; + if (info == null || (reader = fields.get(info.number)) == null) { + return; + } + reader.search(field, target, knnCollector, acceptDocs); } @Override public void close() throws IOException { - IOUtils.close(fields.values()); + List readers = new ArrayList<>(fields.size()); + for (ObjectCursor cursor : fields.values()) { + readers.add(cursor.value); + } + IOUtils.close(readers); } }