From b4fe9cbd5732c00226e4af69eb4048084e102a96 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Tue, 24 Dec 2024 17:34:07 -0600 Subject: [PATCH] Fix shard level rescoring disabled setting flag Signed-off-by: Naveen Tatikonda --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/KNNSettings.java | 2 +- .../nativelib/NativeEngineKnnVectorQuery.java | 6 ++--- .../index/query/rescore/RescoreContext.java | 8 +++--- .../knn/index/KNNSettingsTests.java | 6 ++--- .../NativeEngineKNNVectorQueryTests.java | 4 +-- .../query/rescore/RescoreContextTests.java | 26 +++++++++---------- 7 files changed, 27 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 926e91633..a19b53fd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] * Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346] +* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index b81a54124..6dc72a22b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { + public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() .index(indexName) diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 47ea215f3..5b4d6e7a1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { - boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; - int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); + int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - if (isShardLevelRescoringEnabled == true) { + if (isShardLevelRescoringDisabled == false) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 09aeb7591..03c03a87f 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -61,17 +61,17 @@ public static RescoreContext getDefault() { * based on the vector dimension if shard-level rescoring is disabled. * * @param finalK The final number of results to return for the entire shard. - * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled. - * If true, the dimension-based oversampling logic is bypassed. + * @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled. + * If false, the dimension-based oversampling logic is bypassed. * @param dimension The dimension of the vector. This is used to determine the oversampling factor when * shard-level rescoring is disabled. * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor. */ - public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) { + public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) { // Only apply default dimension-based oversampling logic when: // 1. Shard-level rescoring is disabled // 2. The oversample factor was not provided by the user - if (!isShardLevelRescoringEnabled && !userProvided) { + if (isShardLevelRescoringDisabled && !userProvided) { // Apply new dimension-based oversampling logic when shard-level rescoring is disabled if (dimension >= DIMENSION_THRESHOLD_1000) { oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000 diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index c7a8e7ed8..24990dd36 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { } @SneakyThrows - public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -167,7 +167,7 @@ public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaul mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); KNNSettings.state().setClusterService(clusterService); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertFalse(shardLevelRescoringDisabled); } @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 789bd1054..87c4a5014 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); // Mock ResultUtil to return valid TopDocs mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) @@ -265,7 +265,7 @@ public void testRescore() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java index 2b309e4ab..a0a5cc546 100644 --- a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java +++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java @@ -16,23 +16,23 @@ public void testGetFirstPassK() { float oversample = 2.6f; RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); int finalK = 100; - boolean isShardLevelRescoringEnabled = true; + boolean isShardLevelRescoringDisabled = false; int dimension = 500; // Case 1: Test with standard oversample factor when shard-level rescoring is enabled - assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); } public void testGetFirstPassKWithDimensionBasedOversampling() { @@ -42,44 +42,44 @@ public void testGetFirstPassKWithDimensionBasedOversampling() { // Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled dimension = 1000; RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies - assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling + assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // No oversampling // Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled dimension = 800; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over - assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling + assertEquals(200, rescoreContext.getFirstPassK(finalK, true, dimension)); // 2x oversampling // Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled dimension = 700; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over - assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling + assertEquals(300, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling // Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension) rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user dimension = 500; - assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used + assertEquals(500, rescoreContext.getFirstPassK(finalK, false, dimension)); // User-defined oversample factor should be used // Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS finalK = 10; dimension = 700; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies - assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30 + assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling results in 30 } public void testGetFirstPassKWithMinPassK() { float oversample = 0.5f; RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided - boolean isShardLevelRescoringEnabled = false; + boolean isShardLevelRescoringDisabled = true; // Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS int finalK = 10; int dimension = 700; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS finalK = 100; oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS) rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); } }