From 182541bc7a5eb8b89e7da3eaaab568425bb6833e Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 3 Oct 2024 11:55:22 -0700
Subject: [PATCH] Added rescorer in hybrid query (#917)

* Initial version for rescorer

Signed-off-by: Martin Gaievski

(cherry picked from commit 9f4a49a7e45211821d96181ce2a6842af18ce7ea)
Signed-off-by: Martin Gaievski
---
 CHANGELOG.md                                  |   1 +
 qa/rolling-upgrade/build.gradle               |   9 +
 .../bwc/HybridSearchWithRescoreIT.java        | 145 ++++++++
 .../search/query/HybridCollectorManager.java  | 115 ++++--
 .../query/HybridQueryPhaseSearcher.java       |   4 +-
 .../HybridSearchRescoreQueryException.java    |  17 +
 .../neuralsearch/query/HybridQuerySortIT.java |  64 ++++
 .../query/HybridCollectorManagerTests.java    | 344 ++++++++++++++++++
 8 files changed, 669 insertions(+), 30 deletions(-)
 create mode 100644 qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchWithRescoreIT.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/exception/HybridSearchRescoreQueryException.java

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc72c5d68..569227476 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 ### Enhancements
 - Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
+- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
 ### Bug Fixes
 ### Infrastructure
 ### Documentation
diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle
index 7d21c5f9e..617e6d06a 100644
--- a/qa/rolling-upgrade/build.gradle
+++ b/qa/rolling-upgrade/build.gradle
@@ -76,6 +76,15 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) {
         }
     }
 
+    // Excluding the test because hybrid query with rescore is not compatible with 2.14 and lower
+    if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")
+        || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")
+        || ext.neural_search_bwc_version.startsWith("2.13") || ext.neural_search_bwc_version.startsWith("2.14")) {
+        filter {
+            excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchWithRescoreIT.*"
+        }
+    }
+
     // Excluding the test because we introduce this feature in 2.13
     if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){
         filter {
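For readers coming to this patch cold: it makes the standard `rescore` search option work on top of a `hybrid` query, which is also why the BWC gate above skips this test for clusters on 2.14 and lower. A minimal sketch of the kind of request the feature enables, built with this plugin's HybridQueryBuilder and the stock OpenSearch QueryRescorerBuilder; field names, query strings, and the window size are illustrative assumptions, not values taken from this patch:

import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;

public class HybridRescoreRequestSketch {
    public static SearchSourceBuilder hybridQueryWithRescore() {
        // hybrid query: sub-query scores are normalized and combined by the search pipeline
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(QueryBuilders.matchQuery("passage_text", "hi world"));
        hybridQuery.add(QueryBuilders.termQuery("passage_text", "planet"));

        // second-pass query that re-ranks only the top window of hybrid hits
        QueryRescorerBuilder rescorer = new QueryRescorerBuilder(QueryBuilders.matchQuery("passage_text", "hi"))
            .windowSize(50);

        return new SearchSourceBuilder().query(hybridQuery).addRescorer(rescorer).size(10);
    }
}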
diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchWithRescoreIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchWithRescoreIT.java
new file mode 100644
index 000000000..03b4ae42d
--- /dev/null
+++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchWithRescoreIT.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.bwc;
+
+import org.opensearch.index.query.MatchQueryBuilder;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.knn.index.query.rescore.RescoreContext;
+import org.opensearch.neuralsearch.query.HybridQueryBuilder;
+import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.opensearch.neuralsearch.util.TestUtils.*;
+
+public class HybridSearchWithRescoreIT extends AbstractRollingUpgradeTestCase {
+
+    private static final String PIPELINE_NAME = "nlp-hybrid-with_rescore-pipeline";
+    private static final String SEARCH_PIPELINE_NAME = "nlp-search-with_rescore-pipeline";
+    private static final String TEST_FIELD = "passage_text";
+    private static final String TEXT = "Hello world";
+    private static final String TEXT_MIXED = "Hi planet";
+    private static final String TEXT_UPGRADED = "Hi earth";
+    private static final String QUERY = "Hi world";
+    private static final int NUM_DOCS_PER_ROUND = 1;
+    private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding";
+    protected static final String RESCORE_QUERY = "hi";
+    private static String modelId = "";
+
+    /**
+     * Test normalization with hybrid query and rescore. This test is required because rescore is not compatible with versions lower than 2.15.
+     */
+    public void testNormalizationProcessorWithRescore_whenIndexWithMultipleShards_E2EFlow() throws Exception {
+        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
+        switch (getClusterType()) {
+            case OLD:
+                modelId = uploadTextEmbeddingModel();
+                loadModel(modelId);
+                createPipelineProcessor(modelId, PIPELINE_NAME);
+                createIndexWithConfiguration(
+                    getIndexNameForTest(),
+                    Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
+                    PIPELINE_NAME
+                );
+                addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
+                createSearchPipeline(
+                    SEARCH_PIPELINE_NAME,
+                    DEFAULT_NORMALIZATION_METHOD,
+                    DEFAULT_COMBINATION_METHOD,
+                    Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
+                );
+                break;
+            case MIXED:
+                modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
+                int totalDocsCountMixed;
+                if (isFirstMixedRound()) {
+                    totalDocsCountMixed = NUM_DOCS_PER_ROUND;
+                    HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
+                    QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
+                    validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
+                    addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
+                } else {
+                    totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
+                    HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
+                    validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
+                }
+                break;
+            case UPGRADED:
+                try {
+                    modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
+                    int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
+                    loadModel(modelId);
+                    addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
+                    HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
+                    QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
+                    validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
+                    hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
+                    validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
+                } finally {
+                    wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
+                }
+                break;
+            default:
+                throw new IllegalStateException("Unexpected value: " + getClusterType());
+        }
+    }
+
+    private void validateTestIndexOnUpgrade(
+        final int numberOfDocs,
+        final String modelId,
+        HybridQueryBuilder hybridQueryBuilder,
+        QueryBuilder rescorer
+    ) throws Exception {
+        int docCount = getDocCount(getIndexNameForTest());
+        assertEquals(numberOfDocs, docCount);
+        loadModel(modelId);
+        Map<String, Object> searchResponseAsMap = search(
+            getIndexNameForTest(),
+            hybridQueryBuilder,
+            rescorer,
+            1,
+            Map.of("search_pipeline", SEARCH_PIPELINE_NAME)
+        );
+        assertNotNull(searchResponseAsMap);
+        int hits = getHitCount(searchResponseAsMap);
+        assertEquals(1, hits);
+        List<Double> scoresList = getNormalizationScoreList(searchResponseAsMap);
+        for (Double score : scoresList) {
+            assertTrue(0 <= score && score <= 2);
+        }
+    }
+
+    private HybridQueryBuilder getQueryBuilder(
+        final String modelId,
+        final Map<String, ?> methodParameters,
+        final RescoreContext rescoreContextForNeuralQuery
+    ) {
+        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
+        neuralQueryBuilder.fieldName(VECTOR_EMBEDDING_FIELD);
+        neuralQueryBuilder.modelId(modelId);
+        neuralQueryBuilder.queryText(QUERY);
+        neuralQueryBuilder.k(5);
+        if (methodParameters != null) {
+            neuralQueryBuilder.methodParameters(methodParameters);
+        }
+        if (Objects.nonNull(rescoreContextForNeuralQuery)) {
+            neuralQueryBuilder.rescoreContext(rescoreContextForNeuralQuery);
+        }
+
+        MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
+
+        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+        hybridQueryBuilder.add(matchQueryBuilder);
+        hybridQueryBuilder.add(neuralQueryBuilder);
+
+        return hybridQueryBuilder;
+    }
+}
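The production change that follows leans on one contract from the OpenSearch rescore API: every RescoreContext exposes a Rescorer whose rescore call rewrites a TopDocs instance, and contexts apply in order. A compact sketch of that chaining, with error handling omitted (the real code below wraps the checked IOException into the new HybridSearchRescoreQueryException):

import java.io.IOException;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.opensearch.search.rescore.RescoreContext;

public final class RescoreChainSketch {
    // The output of one rescorer feeds the next; HybridCollectorManager#rescore
    // below runs this same loop over each sub-query's TopDocs.
    public static TopDocs applyAll(TopDocs topDocs, IndexSearcher searcher, Iterable<RescoreContext> contexts) throws IOException {
        for (RescoreContext ctx : contexts) {
            topDocs = ctx.rescorer().rescore(topDocs, searcher, ctx);
        }
        return topDocs;
    }
}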
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index 4eb49e845..f9457f6ca 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -6,6 +6,7 @@
 import java.util.Locale;
 import lombok.RequiredArgsConstructor;
+import lombok.extern.log4j.Log4j2;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.CollectorManager;
@@ -33,7 +34,9 @@
 import org.opensearch.search.query.MultiCollectorWrapper;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.search.query.ReduceableSearchResult;
+import org.opensearch.search.rescore.RescoreContext;
 import org.opensearch.search.sort.SortAndFormats;
+import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -55,6 +58,7 @@
  * In most cases it will be wrapped in MultiCollectorManager.
  */
 @RequiredArgsConstructor
+@Log4j2
 public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
 
     private final int numHits;
@@ -67,6 +71,7 @@ public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
+    private final SearchContext searchContext;
@@ ... @@
     private List<ReduceableSearchResult> getSearchResults(final List<HybridSearchCollector> hybridSearchCollectors) {
         List<ReduceableSearchResult> results = new ArrayList<>();
         DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
         for (HybridSearchCollector collector : hybridSearchCollectors) {
-            TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats);
+            boolean isSortEnabled = docValueFormats != null;
+            TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, isSortEnabled);
             results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats));
         }
         return results;
     }
 
-    private TopDocsAndMaxScore getTopDocsAndAndMaxScore(
-        final HybridSearchCollector hybridSearchCollector,
-        final DocValueFormat[] docValueFormats
-    ) {
-        TopDocs newTopDocs;
+    private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final HybridSearchCollector hybridSearchCollector, final boolean isSortEnabled) {
         List<TopDocs> topDocs = hybridSearchCollector.topDocs();
-        if (docValueFormats != null) {
-            newTopDocs = getNewTopFieldDocs(
-                getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
-                topDocs,
-                sortAndFormats.sort.getSort()
-            );
-        } else {
-            newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs);
+        if (isSortEnabled) {
+            return getSortedTopDocsAndMaxScore(topDocs, hybridSearchCollector);
+        }
+        return getTopDocsAndMaxScore(topDocs, hybridSearchCollector);
+    }
+
+    private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List<TopDocs> topDocs, HybridSearchCollector hybridSearchCollector) {
+        TopDocs sortedTopDocs = getNewTopFieldDocs(
+            getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
+            topDocs,
+            sortAndFormats.sort.getSort()
+        );
+        return new TopDocsAndMaxScore(sortedTopDocs, hybridSearchCollector.getMaxScore());
+    }
+
+    private TopDocsAndMaxScore getTopDocsAndMaxScore(List<TopDocs> topDocs, HybridSearchCollector hybridSearchCollector) {
+        if (shouldRescore()) {
+            topDocs = rescore(topDocs);
+        }
+        float maxScore = calculateMaxScore(topDocs, hybridSearchCollector.getMaxScore());
+        TopDocs finalTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs);
+        return new TopDocsAndMaxScore(finalTopDocs, maxScore);
+    }
+
+    private boolean shouldRescore() {
+        List<RescoreContext> rescoreContexts = searchContext.rescore();
+        return Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty();
+    }
+
+    private List<TopDocs> rescore(List<TopDocs> topDocs) {
+        List<TopDocs> rescoredTopDocs = topDocs;
+        for (RescoreContext ctx : searchContext.rescore()) {
+            rescoredTopDocs = rescoredTopDocs(ctx, rescoredTopDocs);
+        }
+        return rescoredTopDocs;
+    }
+
+    /**
+     * Rescores the top documents using the provided context. The input topDocs may be modified during this process.
+     */
+    private List<TopDocs> rescoredTopDocs(final RescoreContext ctx, final List<TopDocs> topDocs) {
+        List<TopDocs> result = new ArrayList<>(topDocs.size());
+        for (TopDocs topDoc : topDocs) {
+            try {
+                result.add(ctx.rescorer().rescore(topDoc, searchContext.searcher(), ctx));
+            } catch (IOException exception) {
+                log.error("rescore failed for hybrid query in collector_manager.reduce call", exception);
+                throw new HybridSearchRescoreQueryException(exception);
+            }
         }
-        return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore());
+        return result;
+    }
+
+    /**
+     * Calculates the maximum score from the provided TopDocs, considering rescoring.
+     */
+    private float calculateMaxScore(List<TopDocs> topDocsList, float initialMaxScore) {
+        List<RescoreContext> rescoreContexts = searchContext.rescore();
+        if (Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty()) {
+            for (TopDocs topDocs : topDocsList) {
+                if (Objects.nonNull(topDocs.scoreDocs) && topDocs.scoreDocs.length > 0) {
+                    // first top doc for each sub-query has the max score because top docs are sorted by score desc
+                    initialMaxScore = Math.max(initialMaxScore, topDocs.scoreDocs[0].score);
+                }
+            }
+        }
+        return initialMaxScore;
+    }
 
     private List<HybridSearchCollector> getHybridSearchCollectors(final Collection<Collector> collectors) {
@@ -415,18 +472,18 @@ public HybridCollectorNonConcurrentManager(
             int numHits,
             HitsThresholdChecker hitsThresholdChecker,
             int trackTotalHitsUpTo,
-            SortAndFormats sortAndFormats,
             Weight filteringWeight,
-            ScoreDoc searchAfter
+            SearchContext searchContext
         ) {
             super(
                 numHits,
                 hitsThresholdChecker,
                 trackTotalHitsUpTo,
-                sortAndFormats,
+                searchContext.sort(),
                 filteringWeight,
-                new TopDocsMerger(sortAndFormats),
-                (FieldDoc) searchAfter
+                new TopDocsMerger(searchContext.sort()),
+                searchContext.searchAfter(),
+                searchContext
             );
             scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
         }
@@ -453,18 +510,18 @@ public HybridCollectorConcurrentSearchManager(
             int numHits,
             HitsThresholdChecker hitsThresholdChecker,
             int trackTotalHitsUpTo,
-            SortAndFormats sortAndFormats,
             Weight filteringWeight,
-            ScoreDoc searchAfter
+            SearchContext searchContext
         ) {
             super(
                 numHits,
                 hitsThresholdChecker,
                 trackTotalHitsUpTo,
-                sortAndFormats,
+                searchContext.sort(),
                 filteringWeight,
-                new TopDocsMerger(sortAndFormats),
-                (FieldDoc) searchAfter
+                new TopDocsMerger(searchContext.sort()),
+                searchContext.searchAfter(),
+                searchContext
             );
         }
     }
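One subtlety in the reduce path above: rescoring can lift a document's score past anything the collector saw during the first pass, so the manager recomputes the max score instead of trusting the collector's value. Each sub-query's TopDocs stays sorted by score descending, so checking scoreDocs[0] per sub-query suffices, exactly as calculateMaxScore does. A self-contained illustration with made-up numbers:

import java.util.List;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public final class MaxScoreSketch {
    public static float maxAfterRescore(List<TopDocs> perSubQueryTopDocs, float collectorMaxScore) {
        float max = collectorMaxScore;
        for (TopDocs topDocs : perSubQueryTopDocs) {
            if (topDocs.scoreDocs != null && topDocs.scoreDocs.length > 0) {
                // first entry holds the sub-query's max because scores are sorted desc
                max = Math.max(max, topDocs.scoreDocs[0].score);
            }
        }
        return max;
    }

    public static void main(String[] args) {
        TopDocs sub = new TopDocs(
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(4, 1.7f), new ScoreDoc(1, 0.4f) }
        );
        // the rescorer pushed doc 4 to 1.7 while the collector only ever saw 1.2
        System.out.println(maxAfterRescore(List.of(sub), 1.2f)); // prints 1.7
    }
}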
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index 8c7390406..411127507 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -66,7 +66,9 @@ public boolean searchWith(
         }
         Query hybridQuery = extractHybridQuery(searchContext, query);
         QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
-        return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
+        queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
+        // we decide on rescore later in collector manager
+        return false;
     }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/exception/HybridSearchRescoreQueryException.java b/src/main/java/org/opensearch/neuralsearch/search/query/exception/HybridSearchRescoreQueryException.java
new file mode 100644
index 000000000..34933a8e9
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/exception/HybridSearchRescoreQueryException.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.query.exception;
+
+import org.opensearch.OpenSearchException;
+
+/**
+ * Exception thrown when there is an issue with the hybrid search rescore query.
+ */
+public class HybridSearchRescoreQueryException extends OpenSearchException {
+
+    public HybridSearchRescoreQueryException(Throwable cause) {
+        super("rescore failed for hybrid query", cause);
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java
index e6440cc61..b5e812780 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java
@@ -13,6 +13,7 @@
 import lombok.SneakyThrows;
 import org.junit.BeforeClass;
 import org.opensearch.client.ResponseException;
+import org.opensearch.index.query.QueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.MatchQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
@@ -467,6 +468,69 @@ public void testSearchAfter_whenAfterFieldIsNotPassed_thenFail() {
         }
     }
 
+    @SneakyThrows
+    public void testSortingWithRescoreWhenConcurrentSegmentSearchEnabledAndDisabled_whenBothSortAndRescorePresent_thenFail() {
+        try {
+            prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
+            HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
+                "mission",
+                "part",
+                LTE_OF_RANGE_IN_HYBRID_QUERY,
+                GTE_OF_RANGE_IN_HYBRID_QUERY
+            );
+
+            Map<String, SortOrder> fieldSortOrderMap = new HashMap<>();
+            fieldSortOrderMap.put("stock", SortOrder.DESC);
+
+            List<Object> searchAfter = new ArrayList<>();
+            searchAfter.add(25);
+
+            QueryBuilder rescoreQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, TEXT_FIELD_VALUE_1_DUNES);
+
+            assertThrows(
+                "Cannot use [sort] option in conjunction with [rescore].",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    rescoreQuery,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    false,
+                    searchAfter,
+                    0
+                )
+            );
+
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
+
+            assertThrows(
+                "Cannot use [sort] option in conjunction with [rescore].",
+                ResponseException.class,
+                () -> search(
+                    TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
+                    hybridQueryBuilder,
+                    rescoreQuery,
+                    10,
+                    Map.of("search_pipeline", SEARCH_PIPELINE),
+                    null,
+                    null,
+                    createSortBuilders(fieldSortOrderMap, false),
+                    false,
+                    searchAfter,
+                    0
+                )
+            );
+        } finally {
+            wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
+            updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
+        }
+    }
+
     private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) {
         MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text);
         TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value);
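The test above pins behavior that comes from OpenSearch core rather than from this patch: `sort` and `rescore` are mutually exclusive, because rescoring recomputes the relevance scores that an explicit field sort throws away. In builder form the two valid alternatives look like this (a sketch; the field and query values are assumed):

import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.sort.SortOrder;

public final class SortXorRescoreSketch {
    // either sort by a field (relevance scores are not recomputed) ...
    public static SearchSourceBuilder sorted(SearchSourceBuilder source) {
        return source.sort("stock", SortOrder.DESC);
    }

    // ... or rescore by relevance; combining both fails with
    // "Cannot use [sort] option in conjunction with [rescore]."
    public static SearchSourceBuilder rescored(SearchSourceBuilder source) {
        return source.addRescorer(new QueryRescorerBuilder(QueryBuilders.matchQuery("title", "dunes")));
    }
}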
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
index de9c6006b..1d3bc29e9 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
@@ -5,6 +5,8 @@
 package org.opensearch.neuralsearch.search.query;
 
 import com.carrotsearch.randomizedtesting.RandomizedTest;
+
+import java.io.IOException;
 import java.util.Arrays;
 import lombok.SneakyThrows;
 import org.apache.lucene.document.FieldType;
@@ -12,16 +14,19 @@
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexReaderContext;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
@@ -44,6 +49,7 @@
 import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
 import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
 import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
+import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
 import org.opensearch.search.DocValueFormat;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
@@ -54,11 +60,17 @@
 import java.util.List;
 import java.util.Map;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP;
+
+import org.opensearch.search.rescore.QueryRescorerBuilder;
+import org.opensearch.search.rescore.RescoreContext;
+import org.opensearch.search.rescore.Rescorer;
+import org.opensearch.search.rescore.RescorerBuilder;
 import org.opensearch.search.sort.SortAndFormats;
 
 public class HybridCollectorManagerTests extends OpenSearchQueryTestCase {
@@ -70,6 +82,7 @@
     private static final String QUERY1 = "hello";
     private static final String QUERY2 = "hi";
     private static final float DELTA_FOR_ASSERTION = 0.001f;
+    protected static final String QUERY3 = "everyone";
 
     @SneakyThrows
     public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() {
@@ -734,4 +747,335 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
         reader2.close();
         directory2.close();
     }
+
+    @SneakyThrows
+    public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        HybridQuery hybridQueryWithTerm = new HybridQuery(
+            List.of(
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
+            )
+        );
+        when(searchContext.query()).thenReturn(hybridQueryWithTerm);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexReader.numDocs()).thenReturn(3);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+        when(searchContext.size()).thenReturn(2);
+        IndexReaderContext indexReaderContext = mock(IndexReaderContext.class);
+        when(indexReader.getContext()).thenReturn(indexReaderContext);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
+
+        Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        int docId1 = RandomizedTest.randomInt();
+        int docId2 = RandomizedTest.randomInt();
+        int docId3 = RandomizedTest.randomInt();
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
+        w.flush();
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+        IndexSearcher searcher = newSearcher(reader);
+
+        RescorerBuilder rescorerBuilder = new QueryRescorerBuilder(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2));
+        RescoreContext rescoreContext = rescorerBuilder.buildContext(mockQueryShardContext);
+        List<RescoreContext> rescoreContexts = List.of(rescoreContext);
+        when(searchContext.rescore()).thenReturn(rescoreContexts);
+        when(indexReader.leaves()).thenReturn(reader.leaves());
+        Weight rescoreWeight = mock(Weight.class);
+        Scorer rescoreScorer = mock(Scorer.class);
+        when(rescoreWeight.scorer(any())).thenReturn(rescoreScorer);
+        when(rescoreScorer.docID()).thenReturn(1);
+        DocIdSetIterator iterator = mock(DocIdSetIterator.class);
+        when(rescoreScorer.iterator()).thenReturn(iterator);
+        when(rescoreScorer.score()).thenReturn(0.9f);
+        when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE), eq(1f))).thenReturn(rescoreWeight);
+
+        CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager1.newCollector();
+
+        QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
+
+        Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext);
+        ParsedQuery parsedQuery = new ParsedQuery(pfQuery);
+        searchContext.parsedQuery(parsedQuery);
+        when(searchContext.parsedPostFilter()).thenReturn(parsedQuery);
+        when(indexSearcher.rewrite(pfQuery)).thenReturn(pfQuery);
+        Weight postFilterWeight = mock(Weight.class);
+        when(indexSearcher.createWeight(pfQuery, ScoreMode.COMPLETE_NO_SCORES, 1f)).thenReturn(postFilterWeight);
+
+        CollectorManager hybridCollectorManager2 = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        FilteredCollector filteredCollector = (FilteredCollector) hybridCollectorManager2.newCollector();
+
+        Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST);
+        collector.setWeight(weight);
+        filteredCollector.setWeight(weight);
+        LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
+        LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext);
+        LeafCollector filteredCollectorLeafCollector = filteredCollector.getLeafCollector(leafReaderContext);
+        BulkScorer scorer = weight.bulkScorer(leafReaderContext);
+        scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs());
+        leafCollector.finish();
+        scorer.score(filteredCollectorLeafCollector, leafReaderContext.reader().getLiveDocs());
+        filteredCollectorLeafCollector.finish();
+
+        Object results1 = hybridCollectorManager1.reduce(List.of());
+        Object results2 = hybridCollectorManager2.reduce(List.of());
+
+        assertNotNull(results1);
+        assertNotNull(results2);
+        ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results1);
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        reduceableSearchResult.reduce(querySearchResult);
+        TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
+
+        assertNotNull(topDocsAndMaxScore);
+        assertEquals(2, topDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation);
+        float maxScore = topDocsAndMaxScore.maxScore;
+        assertTrue(maxScore > 0);
+        ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(6, scoreDocs.length);
+        assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION);
+        assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION);
+        assertTrue(maxScore >= scoreDocs[2].score);
+        assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION);
+        assertEquals(maxScore, scoreDocs[4].score, DELTA_FOR_ASSERTION);
+        assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION);
+
+        w.close();
+        reader.close();
+        directory.close();
+    }
+
+    @SneakyThrows
+    public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_thenSuccessful() {
+        SearchContext searchContext = mock(SearchContext.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        HybridQuery hybridQueryWithTerm = new HybridQuery(
+            List.of(
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext),
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext)
+            )
+        );
+        when(searchContext.query()).thenReturn(hybridQueryWithTerm);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexReader.numDocs()).thenReturn(2);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+        when(searchContext.size()).thenReturn(1);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(true);
+        // index segment 1
+        Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        int docId1 = RandomizedTest.randomInt();
+        int docId2 = RandomizedTest.randomInt();
+        int docId3 = RandomizedTest.randomInt();
+
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
+        w.flush();
+        w.commit();
+
+        // index segment 2
+        SearchContext searchContext2 = mock(SearchContext.class);
+
+        ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class);
+        IndexReader indexReader2 = mock(IndexReader.class);
+        when(indexReader2.numDocs()).thenReturn(1);
+        when(indexSearcher2.getIndexReader()).thenReturn(indexReader);
+        when(searchContext2.searcher()).thenReturn(indexSearcher2);
+        when(searchContext2.size()).thenReturn(1);
+
+        when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>());
+        when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true);
+
+        Directory directory2 = newDirectory();
+        final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED);
+        ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft2.setOmitNorms(random().nextBoolean());
+        ft2.freeze();
+
+        w2.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
+        w2.flush();
+        w2.commit();
+
+        IndexReader reader1 = DirectoryReader.open(w);
+        IndexSearcher searcher1 = newSearcher(reader1);
+        IndexReader reader2 = DirectoryReader.open(w2);
+        IndexSearcher searcher2 = newSearcher(reader2);
+
+        List<LeafReaderContext> leafReaderContexts = reader1.leaves();
+        IndexReaderContext indexReaderContext = mock(IndexReaderContext.class);
+        when(indexReader.getContext()).thenReturn(indexReaderContext);
+        when(indexReader.leaves()).thenReturn(leafReaderContexts);
+        // set up rescorer in a way that it boosts the second document from the first segment
+        RescorerBuilder rescorerBuilder = new QueryRescorerBuilder(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2));
+        RescoreContext rescoreContext = rescorerBuilder.buildContext(mockQueryShardContext);
+        List<RescoreContext> rescoreContexts = List.of(rescoreContext);
+        when(searchContext.rescore()).thenReturn(rescoreContexts);
+        Weight rescoreWeight = mock(Weight.class);
+        Scorer rescoreScorer = mock(Scorer.class);
+        when(rescoreWeight.scorer(any())).thenReturn(rescoreScorer);
+        when(rescoreScorer.docID()).thenReturn(1);
+        DocIdSetIterator iterator = mock(DocIdSetIterator.class);
+        when(rescoreScorer.iterator()).thenReturn(iterator);
+        when(rescoreScorer.score()).thenReturn(0.9f);
+        when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE), eq(1f))).thenReturn(rescoreWeight);
+
+        CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector();
+        HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector();
+
+        Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher1, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST);
+        Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST);
+        collector1.setWeight(weight1);
+        collector2.setWeight(weight2);
+
+        LeafReaderContext leafReaderContext = searcher1.getIndexReader().leaves().get(0);
+        LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext);
+        BulkScorer scorer = weight1.bulkScorer(leafReaderContext);
+        scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs());
+        leafCollector1.finish();
+
+        LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0);
+        LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2);
+        BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2);
+        scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs());
+        leafCollector2.finish();
+
+        Object results = hybridCollectorManager.reduce(List.of(collector1, collector2));
+
+        // assert that the second search hit in the result has the max score due to the boost from the rescorer
+        assertNotNull(results);
+        ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results);
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        reduceableSearchResult.reduce(querySearchResult);
+        TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
+
+        assertNotNull(topDocsAndMaxScore);
+        assertEquals(3, topDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation);
+        float maxScore = topDocsAndMaxScore.maxScore;
+        assertTrue(maxScore > 0);
+        ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs;
+        assertEquals(8, scoreDocs.length);
+        assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION);
+        assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION);
+        assertTrue(maxScore > scoreDocs[2].score);
+        assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION);
+        assertEquals(maxScore, scoreDocs[4].score, DELTA_FOR_ASSERTION);
+        assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[5].score, DELTA_FOR_ASSERTION);
+        assertTrue(maxScore > scoreDocs[6].score);
+        assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[7].score, DELTA_FOR_ASSERTION);
+
+        // release resources
+        w.close();
+        reader1.close();
+        directory.close();
+        w2.close();
+        reader2.close();
+        directory2.close();
+    }
+
+    @SneakyThrows
+    public void testReduceAndRescore_whenRescorerThrowsException_thenFail() {
+        SearchContext searchContext = mock(SearchContext.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
+
+        HybridQuery hybridQueryWithTerm = new HybridQuery(
+            List.of(
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
+                QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
+            )
+        );
+        when(searchContext.query()).thenReturn(hybridQueryWithTerm);
+        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
+        IndexReader indexReader = mock(IndexReader.class);
+        when(indexReader.numDocs()).thenReturn(3);
+        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
+        when(searchContext.searcher()).thenReturn(indexSearcher);
+        when(searchContext.size()).thenReturn(2);
+        IndexReaderContext indexReaderContext = mock(IndexReaderContext.class);
+        when(indexReader.getContext()).thenReturn(indexReaderContext);
+
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
+        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
+
+        Directory directory = newDirectory();
+        final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+        FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
+        ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
+        ft.setOmitNorms(random().nextBoolean());
+        ft.freeze();
+
+        int docId1 = RandomizedTest.randomInt();
+        w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
+        w.flush();
+        w.commit();
+
+        IndexReader reader = DirectoryReader.open(w);
+        IndexSearcher searcher = newSearcher(reader);
+
+        RescoreContext rescoreContext = mock(RescoreContext.class);
+        Rescorer rescorer = mock(Rescorer.class);
+        when(rescoreContext.rescorer()).thenReturn(rescorer);
+        when(rescorer.rescore(any(), any(), any())).thenThrow(new IOException("something happened with rescorer"));
+        List<RescoreContext> rescoreContexts = List.of(rescoreContext);
+        when(searchContext.rescore()).thenReturn(rescoreContexts);
+
+        CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager1.newCollector();
+
+        Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST);
+        collector.setWeight(weight);
+
+        LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
+        LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext);
+
+        BulkScorer scorer = weight.bulkScorer(leafReaderContext);
+        scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs());
+        leafCollector.finish();
+
+        expectThrows(HybridSearchRescoreQueryException.class, () -> hybridCollectorManager1.reduce(List.of()));
+
+        // release resources
+        w.close();
+        reader.close();
+        directory.close();
+    }
 }
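A closing note on the failure path exercised by the last test: an IOException thrown by any Rescorer during reduce is not swallowed; it aborts the search and surfaces as the new unchecked HybridSearchRescoreQueryException. A standalone illustration of that wrapping (the FailingStep interface and helper are hypothetical; only the exception class comes from this patch):

import java.io.IOException;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;

public final class RescoreFailureSketch {
    interface FailingStep {
        void run() throws IOException;
    }

    // mirrors the try/catch in HybridCollectorManager#rescoredTopDocs
    static void runRescoreStep(FailingStep step) {
        try {
            step.run();
        } catch (IOException e) {
            // wraps the cause with the message "rescore failed for hybrid query"
            throw new HybridSearchRescoreQueryException(e);
        }
    }

    public static void main(String[] args) {
        try {
            runRescoreStep(() -> { throw new IOException("something happened with rescorer"); });
        } catch (HybridSearchRescoreQueryException e) {
            System.out.println(e.getMessage());
        }
    }
}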