From 1778377176b901d70d994b71de30839acae224cb Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 9 Jan 2025 15:59:44 -0500 Subject: [PATCH] fix gh-14123: Add null checks to SortingCodecReader (#14125) --- lucene/CHANGES.txt | 2 + .../lucene/index/SortingCodecReader.java | 22 ++++++++++ .../lucene/index/TestSortingCodecReader.java | 40 +++++++++++++------ 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c569db7f7ad0..0e7c6b13ee5b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -66,6 +66,8 @@ Bug Fixes * GITHUB#14109: prefetch may select the wrong memory segment for multi-segment slices. (Chris Hegarty) +* GITHUB#14123: SortingCodecReader NPE when segment has no (points, vectors, etc...) (Mike Sokolov) + Other --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index daec0c197d6a..ab9964026ad8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -314,6 +314,7 @@ private static class SortingFloatVectorValues extends FloatVectorValues { SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; + assert delegate != null; // SortingValuesIterator consumes the iterator and records the docs and ord mapping iteratorSupplier = iteratorSupplier(delegate, sortMap); } @@ -446,6 +447,9 @@ private SortingCodecReader( @Override public FieldsProducer getPostingsReader() { FieldsProducer postingsReader = in.getPostingsReader(); + if (postingsReader == null) { + return null; + } return new FieldsProducer() { @Override public void close() throws IOException { @@ -481,6 +485,9 @@ public int size() { @Override public StoredFieldsReader getFieldsReader() { StoredFieldsReader delegate = in.getFieldsReader(); + if (delegate == null) { + return null; + } return newStoredFieldsReader(delegate); } @@ -526,6 +533,9 @@ public Bits getLiveDocs() { @Override public PointsReader getPointsReader() { final PointsReader delegate = in.getPointsReader(); + if (delegate == null) { + return null; + } return new PointsReader() { @Override public void checkIntegrity() throws IOException { @@ -551,6 +561,9 @@ public void close() throws IOException { @Override public KnnVectorsReader getVectorReader() { KnnVectorsReader delegate = in.getVectorReader(); + if (delegate == null) { + return null; + } return new KnnVectorsReader() { @Override public void checkIntegrity() throws IOException { @@ -587,6 +600,9 @@ public void close() throws IOException { @Override public NormsProducer getNormsReader() { final NormsProducer delegate = in.getNormsReader(); + if (delegate == null) { + return null; + } return new NormsProducer() { @Override public NumericDocValues getNorms(FieldInfo field) throws IOException { @@ -609,6 +625,9 @@ public void close() throws IOException { @Override public DocValuesProducer getDocValuesReader() { final DocValuesProducer delegate = in.getDocValuesReader(); + if (delegate == null) { + return null; + } return new DocValuesProducer() { @Override public NumericDocValues getNumeric(FieldInfo field) throws IOException { @@ -710,6 +729,9 @@ public TermVectorsReader getTermVectorsReader() { } private TermVectorsReader newTermVectorsReader(TermVectorsReader delegate) { + if (delegate == null) { + return null; + } return new TermVectorsReader() { @Override public void prefetch(int doc) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index 8039d8b8f6fb..285296d55c19 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.document.BinaryDocValuesField; @@ -153,14 +154,16 @@ public void testSortOnAddIndicesRandom() throws IOException { docIds.add(i); } Collections.shuffle(docIds, random()); - // If true, index a vector for every doc - boolean denseVectors = random().nextBoolean(); + // If true, index a vector and points for every doc + boolean dense = random().nextBoolean(); try (RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) { for (int i = 0; i < numDocs; i++) { int docId = docIds.get(i); Document doc = new Document(); doc.add(new StringField("string_id", Integer.toString(docId), Field.Store.YES)); - doc.add(new LongPoint("point_id", docId)); + if (dense || docId % 3 == 0) { + doc.add(new LongPoint("point_id", docId)); + } String s = RandomStrings.randomRealisticUnicodeOfLength(random(), 25); doc.add(new TextField("text_field", s, Field.Store.YES)); doc.add(new BinaryDocValuesField("text_field", new BytesRef(s))); @@ -172,7 +175,7 @@ public void testSortOnAddIndicesRandom() throws IOException { doc.add(new BinaryDocValuesField("binary_dv", new BytesRef(Integer.toString(docId)))); doc.add( new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId)))); - if (denseVectors || docId % 2 == 0) { + if (dense || docId % 2 == 0) { doc.add(new KnnFloatVectorField("vector", new float[] {(float) docId})); } doc.add(new NumericDocValuesField("foo", random().nextInt(20))); @@ -245,8 +248,13 @@ public void testSortOnAddIndicesRandom() throws IOException { SortedSetDocValues sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv"); SortedDocValues binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv"); FloatVectorValues vectorValues = leaf.getFloatVectorValues("vector"); - HnswGraph graph = - ((HnswGraphProvider) ((CodecReader) leaf).getVectorReader()).getGraph("vector"); + KnnVectorsReader vectorsReader = ((CodecReader) leaf).getVectorReader(); + HnswGraph graph; + if (vectorsReader instanceof HnswGraphProvider hnswGraphProvider) { + graph = hnswGraphProvider.getGraph("vector"); + } else { + graph = null; + } NumericDocValues ids = leaf.getNumericDocValues("id"); long prevValue = -1; boolean usingAltIds = false; @@ -272,10 +280,12 @@ public void testSortOnAddIndicesRandom() throws IOException { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); - if (denseVectors || prevValue % 2 == 0) { + if (dense || prevValue % 2 == 0) { assertEquals(idNext, valuesIterator.advance(idNext)); - graph.seek(0, valuesIterator.index()); - assertNotEquals(DocIdSetIterator.NO_MORE_DOCS, graph.nextNeighbor()); + if (graph != null) { + graph.seek(0, valuesIterator.index()); + assertNotEquals(DocIdSetIterator.NO_MORE_DOCS, graph.nextNeighbor()); + } } assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); @@ -289,7 +299,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); - if (denseVectors || prevValue % 2 == 0) { + if (dense || prevValue % 2 == 0) { float[] vectorValue = vectorValues.vectorValue(valuesIterator.index()); assertEquals(1, vectorValue.length); assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); @@ -306,9 +316,13 @@ public void testSortOnAddIndicesRandom() throws IOException { leaf.storedFields().document(idNext).get("string_id")); IndexSearcher searcher = new IndexSearcher(r); TopDocs result = - searcher.search(LongPoint.newExactQuery("point_id", ids.longValue()), 1); - assertEquals(1, result.totalHits.value()); - assertEquals(idNext, result.scoreDocs[0].doc); + searcher.search(LongPoint.newExactQuery("point_id", ids.longValue()), 10); + if (dense || ids.longValue() % 3 == 0) { + assertEquals(1, result.totalHits.value()); + assertEquals(idNext, result.scoreDocs[0].doc); + } else { + assertEquals(0, result.totalHits.value()); + } result = searcher.search(new TermQuery(new Term("string_id", "" + ids.longValue())), 1);