Skip to content

Commit

Permalink
[9.x] Add a MemorySegment Vector scorer - for scoring without copying…
Browse files Browse the repository at this point in the history
… on-heap (#13402)

Add a MemorySegment Vector scorer - for scoring without copying on-heap.

The vector scorer loads values directly from the backing memory segment when available. Otherwise, if the vector data spans across segments the scorer copies the vector data on-heap.

A benchmark shows ~2x performance improvement of this scorer over the default copy-on-heap scorer.

The scorer currently only operates on vectors with an element size of byte. We can evaluate if and how to support floats separately.
  • Loading branch information
ChrisHegarty authored May 21, 2024
1 parent d333dee commit e913d5a
Show file tree
Hide file tree
Showing 30 changed files with 2,049 additions and 31 deletions.
2 changes: 1 addition & 1 deletion gradle/generation/extract-jdk-apis/ExtractJdkApis.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public final class ExtractJdkApis {
static final Map<Integer,List<String>> CLASSFILE_PATTERNS = Map.of(
19, List.of(PATTERN_PANAMA_FOREIGN),
20, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR),
21, List.of(PATTERN_PANAMA_FOREIGN)
21, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR)
);

public static void main(String... args) throws IOException {
Expand Down
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ Optimizations

* GITHUB#13327: Reduce memory usage of field maps in FieldInfos and BlockTree TermsReader. (Bruno Roustant, David Smiley)

* GITHUB#13339: Add a MemorySegment Vector scorer - for scoring without copying on-heap (Chris Hegarty)

Bug Fixes
---------------------

Expand Down
Binary file modified lucene/core/src/generated/jdk/jdk21.apijar
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
* @lucene.experimental
*/
public class DefaultFlatVectorScorer implements FlatVectorsScorer {

public static final DefaultFlatVectorScorer INSTANCE = new DefaultFlatVectorScorer();

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.codecs.hnsw;

import org.apache.lucene.internal.vectorization.VectorizationProvider;

/**
* Utilities for {@link FlatVectorsScorer}.
*
* @lucene.experimental
*/
public final class FlatVectorScorerUtil {

private static final VectorizationProvider IMPL = VectorizationProvider.getInstance();

private FlatVectorScorerUtil() {}

/**
* Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this
* method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned.
*/
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
return IMPL.getLucene99FlatVectorsScorer();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.MergePolicy;
Expand Down Expand Up @@ -139,7 +139,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {

/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

private final int numMergeWorkers;
private final TaskExecutor mergeExec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
Expand Down Expand Up @@ -48,7 +49,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "veq";

private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
Expand Down Expand Up @@ -101,7 +102,8 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.internal.tests;

import org.apache.lucene.store.FilterIndexInput;

/**
* Access to {@link org.apache.lucene.store.FilterIndexInput} internals exposed to the test
* framework.
*
* @lucene.internal
*/
public interface FilterIndexInputAccess {
/** Adds the given test FilterIndexInput class. */
void addTestFilterType(Class<? extends FilterIndexInput> cls);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.FilterIndexInput;

/**
* A set of static methods returning accessors for internal, package-private functionality in
Expand All @@ -48,12 +49,14 @@ public final class TestSecrets {
ensureInitialized.accept(ConcurrentMergeScheduler.class);
ensureInitialized.accept(SegmentReader.class);
ensureInitialized.accept(IndexWriter.class);
ensureInitialized.accept(FilterIndexInput.class);
}

private static IndexPackageAccess indexPackageAccess;
private static ConcurrentMergeSchedulerAccess cmsAccess;
private static SegmentReaderAccess segmentReaderAccess;
private static IndexWriterAccess indexWriterAccess;
private static FilterIndexInputAccess filterIndexInputAccess;

private TestSecrets() {}

Expand Down Expand Up @@ -81,6 +84,12 @@ public static IndexWriterAccess getIndexWriterAccess() {
return Objects.requireNonNull(indexWriterAccess);
}

/** Return the accessor to internal secrets for an {@link FilterIndexInput}. */
public static FilterIndexInputAccess getFilterInputIndexAccess() {
ensureCaller();
return Objects.requireNonNull(filterIndexInputAccess);
}

/** For internal initialization only. */
public static void setIndexWriterAccess(IndexWriterAccess indexWriterAccess) {
ensureNull(TestSecrets.indexWriterAccess);
Expand All @@ -105,6 +114,12 @@ public static void setSegmentReaderAccess(SegmentReaderAccess segmentReaderAcces
TestSecrets.segmentReaderAccess = segmentReaderAccess;
}

/** For internal initialization only. */
public static void setFilterInputIndexAccess(FilterIndexInputAccess filterIndexInputAccess) {
ensureNull(TestSecrets.filterIndexInputAccess);
TestSecrets.filterIndexInputAccess = filterIndexInputAccess;
}

private static void ensureNull(Object ob) {
if (ob != null) {
throw new AssertionError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.lucene.internal.vectorization;

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;

/** Default provider returning scalar implementations. */
final class DefaultVectorizationProvider extends VectorizationProvider {

Expand All @@ -30,4 +33,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider {
public VectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}

@Override
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.function.Predicate;
import java.util.logging.Logger;
import java.util.stream.Stream;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -93,6 +94,9 @@ public static VectorizationProvider getInstance() {
*/
public abstract VectorUtilSupport getVectorUtilSupport();

/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();

// *** Lookup mechanism: ***

private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName());
Expand Down Expand Up @@ -199,7 +203,10 @@ private static boolean isAffectedByJDK8301190() {
}

// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS = Set.of("org.apache.lucene.util.VectorUtil");
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil");

private static void ensureCaller() {
final boolean validCaller =
Expand Down
19 changes: 19 additions & 0 deletions lucene/core/src/java/org/apache/lucene/store/FilterIndexInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.lucene.store;

import java.io.IOException;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.lucene.internal.tests.TestSecrets;

/**
* IndexInput implementation that delegates calls to another directory. This class can be used to
Expand All @@ -29,6 +31,12 @@
*/
public class FilterIndexInput extends IndexInput {

static final CopyOnWriteArrayList<Class<?>> TEST_FILTER_INPUTS = new CopyOnWriteArrayList<>();

static {
TestSecrets.setFilterInputIndexAccess(TEST_FILTER_INPUTS::add);
}

/**
* Unwraps all FilterIndexInputs until the first non-FilterIndexInput IndexInput instance and
* returns it
Expand All @@ -40,6 +48,17 @@ public static IndexInput unwrap(IndexInput in) {
return in;
}

/**
* Unwraps all test FilterIndexInputs until the first non-test FilterIndexInput IndexInput
* instance and returns it
*/
public static IndexInput unwrapOnlyTest(IndexInput in) {
while (in instanceof FilterIndexInput && TEST_FILTER_INPUTS.contains(in.getClass())) {
in = ((FilterIndexInput) in).in;
}
return in;
}

protected final IndexInput in;

/** Creates a FilterIndexInput with a resource description and wrapped delegate IndexInput */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Locale;
import java.util.logging.Logger;
import jdk.incubator.vector.FloatVector;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;

Expand Down Expand Up @@ -73,4 +75,10 @@ private static <T> T doPrivileged(PrivilegedAction<T> action) {
public VectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}

// Use the default scorer on JDK 20
@Override
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}
}
Loading

0 comments on commit e913d5a

Please sign in to comment.