Skip to content

Commit

Permalink
Concurrency optimization for graph native loading
Browse files Browse the repository at this point in the history
Signed-off-by: Ganesh Ramadurai <[email protected]>
  • Loading branch information
Gankris96 authored and Ganesh Ramadurai committed Dec 19, 2024
1 parent dc369e6 commit 7cb8710
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC

// Cache Miss
// Evict before put
nativeMemoryEntryContext.preload();
synchronized (this) {
if (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight) {
Iterator<String> lruIterator = accessRecencyQueue.iterator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
package org.opensearch.knn.index.memory;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.store.IndexInputWithBuffer;

import java.io.IOException;
import java.util.Map;
Expand All @@ -26,7 +30,7 @@
/**
* Encapsulates all information needed to load a component into native memory.
*/
public abstract class NativeMemoryEntryContext<T extends NativeMemoryAllocation> {
public abstract class NativeMemoryEntryContext<T extends NativeMemoryAllocation> implements AutoCloseable {

protected final String key;

Expand Down Expand Up @@ -55,6 +59,18 @@ public String getKey() {
*/
public abstract Integer calculateSizeInKB();

/**
 * Prepares the entry before {@link #load()} is invoked.
 * <p>
 * {@code IndexEntryContext} uses this hook to open the index input for its vector file.
 * Implementations that have nothing to open ahead of time may leave it as a no-op.
 */
public abstract void preload();

/**
 * Closes any closeable objects held by this {@link NativeMemoryEntryContext}.
 * <p>
 * The default implementation does nothing; subclasses that hold closeable resources override it.
 */
@Override
public void close() {}

/**
* Loads entry into memory.
*
Expand All @@ -75,6 +91,18 @@ public static class IndexEntryContext extends NativeMemoryEntryContext<NativeMem
@Getter
private final String modelId;

// Set to true by preload() once the index input has been opened. The generated setter also lets
// callers (e.g. tests with a mocked load strategy) mark the entry as preloaded directly.
@Setter
@Getter
private boolean preloaded = false;
// Length of the vector index file in KB (file length / 1024), computed by preload().
@Getter
private int indexSizeKb;

// Index input opened by preload(); released by close().
@Getter
private IndexInput readStream;

// Wrapper around readStream that the index load strategy hands to JNIService.loadIndex.
@Getter
IndexInputWithBuffer indexInputWithBuffer;

/**
* Constructor
*
Expand Down Expand Up @@ -131,10 +159,55 @@ public Integer calculateSizeInKB() {
}
}

/**
 * Opens the vector index file referenced by this entry's cache key so that a later
 * {@link #load()} can hand the already-open input to the load strategy.
 * <p>
 * On success this records the file size in KB, the opened {@link IndexInput}, its
 * {@link IndexInputWithBuffer} wrapper, and marks the entry as preloaded. On failure the
 * entry's state is left untouched apart from {@code indexSizeKb}, and any input that was
 * already opened is closed again.
 *
 * @throws IllegalStateException if the cache key does not contain a vector file name
 * @throws ArithmeticException if the file size in KB does not fit into an {@code int}
 * @throws RuntimeException if reading the file length, opening the input, or seeking fails;
 *                          the underlying {@link IOException} is attached as the cause
 */
@Override
public void preload() {
    // Extract vector file name from the given cache key.
    // Ex: _0_165_my_field.faiss@1vaqiupVUwvkXAG4Qc/RPg==
    final String cacheKey = this.getKey();
    final String vectorFileName = NativeMemoryCacheKeyHelper.extractVectorIndexFileName(cacheKey);
    if (vectorFileName == null) {
        throw new IllegalStateException(
            "Invalid cache key was given. The key [" + cacheKey + "] does not contain the corresponding vector file name."
        );
    }

    // Prepare for opening index input from directory.
    final Directory directory = this.getDirectory();

    // Keep the freshly opened input in a local until every step succeeded, so a failure
    // half-way through never publishes (or leaks) a partially initialised stream.
    IndexInput openedInput = null;
    try {
        indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);
        openedInput = directory.openInput(vectorFileName, IOContext.READONCE);
        openedInput.seek(0);
        indexInputWithBuffer = new IndexInputWithBuffer(openedInput);
        readStream = openedInput;
        preloaded = true;
    } catch (IOException e) {
        if (openedInput != null) {
            try {
                openedInput.close();
            } catch (IOException closeFailure) {
                // Report the original failure; keep the close failure attached for diagnosis.
                e.addSuppressed(closeFailure);
            }
        }
        // Preserve the cause so the real I/O error is visible in logs and stack traces.
        throw new RuntimeException("Failed to preload the index " + vectorFileName, e);
    }
}

/**
 * Loads the index through the configured load strategy, opening the index input first
 * when {@link #preload()} has not been run yet.
 *
 * @return the allocation produced by the index load strategy
 * @throws IOException if the load strategy fails with an I/O error
 */
@Override
public NativeMemoryAllocation.IndexAllocation load() throws IOException {
    final boolean inputAlreadyOpened = isPreloaded();
    if (inputAlreadyOpened == false) {
        preload();
    }
    final NativeMemoryAllocation.IndexAllocation allocation = indexLoadStrategy.load(this);
    return allocation;
}

/**
 * Closes the index input held by this entry, if one is present.
 *
 * @throws RuntimeException wrapping the {@link IOException} raised while closing the input
 */
@Override
public void close() {
    if (readStream == null) {
        // No input is held, so there is nothing to release.
        return;
    }
    try {
        readStream.close();
    } catch (IOException closeFailure) {
        final String message = "Exception while closing the indexInput index ["
            + openSearchIndexName
            + "] for loading the graph file.";
        throw new RuntimeException(message, closeFailure);
    }
}
}

public static class TrainingDataEntryContext extends NativeMemoryEntryContext<NativeMemoryAllocation.TrainingDataAllocation> {
Expand Down Expand Up @@ -192,6 +265,11 @@ public Integer calculateSizeInKB() {
return size;
}

/**
 * No-op: this entry context does nothing in its preload step.
 */
@Override
public void preload() {
    // Intentionally empty.
}

@Override
public NativeMemoryAllocation.TrainingDataAllocation load() {
return trainingLoadStrategy.load(this);
Expand Down Expand Up @@ -278,6 +356,11 @@ public Integer calculateSizeInKB() {
return size;
}

/**
 * No-op: this entry context does nothing in its preload step.
 */
@Override
public void preload() {
    // Intentionally empty.
}

@Override
public NativeMemoryAllocation.AnonymousAllocation load() throws IOException {
return loadStrategy.load(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.store.IndexInputWithBuffer;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -88,10 +85,16 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);

// Try to open an index input then pass it down to native engine for loading an index.
try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) {
final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream);
final long indexAddress = JNIService.loadIndex(indexInputWithBuffer, indexEntryContext.getParameters(), knnEngine);

// preload takes care of opening the indexInput file
if (!indexEntryContext.isPreloaded()) {
throw new IllegalStateException("Index [" + indexEntryContext.getOpenSearchIndexName() + "] is not preloaded");
}
try (indexEntryContext) {
final long indexAddress = JNIService.loadIndex(
indexEntryContext.indexInputWithBuffer,
indexEntryContext.getParameters(),
knnEngine
);
return createIndexAllocation(indexEntryContext, knnEngine, indexAddress, indexSizeKb, vectorFileName);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ public Integer calculateSizeInKB() {
return size;
}

// No-op: this test context performs no work in its preload step.
@Override
public void preload() {}

@Override
public TestNativeMemoryAllocation load() throws IOException {
return new TestNativeMemoryAllocation(size, memoryAddress);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public void testIndexEntryContext_load() throws IOException {

when(indexLoadStrategy.load(indexEntryContext)).thenReturn(indexAllocation);

// The load strategy is mocked, so mark the context as preloaded; otherwise load() would call preload() first.
indexEntryContext.setPreloaded(true);
assertEquals(indexAllocation, indexEntryContext.load());
}

Expand Down Expand Up @@ -292,6 +294,11 @@ public Integer calculateSizeInKB() {
return size;
}

// No-op: this test context performs no work in its preload step.
@Override
public void preload() {
    // Intentionally empty.
}

@Override
public TestNativeMemoryAllocation load() throws IOException {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ public void testIndexLoadStrategy_load() throws IOException {
);

// Load
NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance()
.load(indexEntryContext);
NativeMemoryAllocation.IndexAllocation indexAllocation = indexEntryContext.load();

// Confirm that the file was loaded by querying
float[] query = new float[dimension];
Expand Down Expand Up @@ -115,8 +114,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException {
);

// Load
NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance()
.load(indexEntryContext);
NativeMemoryAllocation.IndexAllocation indexAllocation = indexEntryContext.load();

// Verify
assertTrue(indexAllocation.isBinaryIndex());
Expand Down

0 comments on commit 7cb8710

Please sign in to comment.