Skip to content

Commit

Permalink
Fix shard level rescoring disabled setting flag (#2352)
Browse files Browse the repository at this point in the history
  • Loading branch information
naveentatikonda authored Dec 25, 2024
1 parent 646d8b7 commit c728f02
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346]
* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled == true) {
if (isShardLevelRescoringDisabled == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ public static RescoreContext getDefault() {
* based on the vector dimension if shard-level rescoring is disabled.
*
* @param finalK The final number of results to return for the entire shard.
* @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
* If true, the dimension-based oversampling logic is bypassed.
* @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled.
* If false, the dimension-based oversampling logic is bypassed.
* @param dimension The dimension of the vector. This is used to determine the oversampling factor when
* shard-level rescoring is disabled.
* @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) {
// Only apply default dimension-based oversampling logic when:
// 1. Shard-level rescoring is disabled
// 2. The oversample factor was not provided by the user
if (!isShardLevelRescoringEnabled && !userProvided) {
if (isShardLevelRescoringDisabled && !userProvided) {
// Apply new dimension-based oversampling logic when shard-level rescoring is disabled
if (dimension >= DIMENSION_THRESHOLD_1000) {
oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}

@SneakyThrows
public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertFalse(shardLevelRescoringDisabled);
}
Expand All @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

// Mock ResultUtil to return valid TopDocs
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt()))
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testRescore() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ public void testGetFirstPassK() {
float oversample = 2.6f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build();
int finalK = 100;
boolean isShardLevelRescoringEnabled = true;
boolean isShardLevelRescoringDisabled = false;
int dimension = 500;

// Case 1: Test with standard oversample factor when shard-level rescoring is enabled
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS
finalK = 1;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS
finalK = 0;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS
finalK = MAX_FIRST_PASS_RESULTS;
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}

public void testGetFirstPassKWithDimensionBasedOversampling() {
Expand All @@ -42,44 +42,44 @@ public void testGetFirstPassKWithDimensionBasedOversampling() {
// Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled
dimension = 1000;
RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // No oversampling

// Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled
dimension = 800;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling
assertEquals(200, rescoreContext.getFirstPassK(finalK, true, dimension)); // 2x oversampling

// Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling
assertEquals(300, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling

// Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension)
rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user
dimension = 500;
assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used
assertEquals(500, rescoreContext.getFirstPassK(finalK, false, dimension)); // User-defined oversample factor should be used

// Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS
finalK = 10;
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling results in 30
}

public void testGetFirstPassKWithMinPassK() {
float oversample = 0.5f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
boolean isShardLevelRescoringEnabled = false;
boolean isShardLevelRescoringDisabled = true;

// Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS
int finalK = 10;
int dimension = 700;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS
finalK = 100;
oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS)
rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}
}

0 comments on commit c728f02

Please sign in to comment.