Skip to content

Commit

Permalink
Added rescorer in hybrid query (#917)
Browse files Browse the repository at this point in the history
* Initial version for rescorer

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 9f4a49a)
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 4, 2024
1 parent 8a786fe commit c7ec5a4
Show file tree
Hide file tree
Showing 8 changed files with 620 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
24 changes: 24 additions & 0 deletions qa/rolling-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) {
}
}

// Excluding the test because hybrid query with rescore is not compatible with 2.14 and lower
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")
|| ext.neural_search_bwc_version.startsWith("2.13") || ext.neural_search_bwc_version.startsWith("2.14")) {
filter {
excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.testNormalizationProcessorWithRescore_whenIndexWithMultipleShards_E2EFlow"
}
}

// Excluding the test because we introduce this feature in 2.13
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){
filter {
Expand Down Expand Up @@ -144,6 +152,14 @@ task testAgainstOneThirdUpgradedCluster(type: StandaloneRestIntegTestTask) {
}
}

// Excluding the test because hybrid query with rescore is not compatible with 2.14 and lower
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")
|| ext.neural_search_bwc_version.startsWith("2.13") || ext.neural_search_bwc_version.startsWith("2.14")) {
filter {
excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.testNormalizationProcessorWithRescore_whenIndexWithMultipleShards_E2EFlow"
}
}

// Excluding the test because we introduce this feature in 2.13
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){
filter {
Expand Down Expand Up @@ -211,6 +227,14 @@ task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) {
}
}

// Excluding the test because hybrid query with rescore is not compatible with 2.14 and lower
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")
|| ext.neural_search_bwc_version.startsWith("2.13") || ext.neural_search_bwc_version.startsWith("2.14")) {
filter {
excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.testNormalizationProcessorWithRescore_whenIndexWithMultipleShards_E2EFlow"
}
}

// Excluding the test because we introduce this feature in 2.13
if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){
filter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.opensearch.index.query.MatchQueryBuilder;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand All @@ -17,6 +19,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
Expand All @@ -25,12 +29,16 @@ public class HybridSearchIT extends AbstractRollingUpgradeTestCase {

private static final String PIPELINE_NAME = "nlp-hybrid-pipeline";
private static final String SEARCH_PIPELINE_NAME = "nlp-search-pipeline";
private static final String PIPELINE_RESCORE_NAME = "nlp-hybrid-with_rescore-pipeline";
private static final String SEARCH_PIPELINE_RESCORE_NAME = "nlp-search-with_rescore-pipeline";
private static final String TEST_FIELD = "passage_text";
private static final String TEXT = "Hello world";
private static final String TEXT_MIXED = "Hi planet";
private static final String TEXT_UPGRADED = "Hi earth";
private static final String QUERY = "Hi world";
private static final int NUM_DOCS_PER_ROUND = 1;
private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding";
protected static final String RESCORE_QUERY = "hi";
private static String modelId = "";

// Test rolling-upgrade normalization processor when index with multiple shards
Expand Down Expand Up @@ -62,12 +70,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
case UPGRADED:
Expand All @@ -77,9 +85,9 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, null);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -89,15 +97,77 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder)
throws Exception {
/**
* Test normalization with hybrid query and rescore. This test is required as rescore will not be compatible with version lower than 2.15
*/
public void testNormalizationProcessorWithRescore_whenIndexWithMultipleShards_E2EFlow() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
switch (getClusterType()) {
case OLD:
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_RESCORE_NAME);
createIndexWithConfiguration(
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
PIPELINE_RESCORE_NAME
);
addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
createSearchPipeline(
SEARCH_PIPELINE_RESCORE_NAME,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
);
break;
case MIXED:
modelId = getModelId(getIngestionPipeline(PIPELINE_RESCORE_NAME), TEXT_EMBEDDING_PROCESSOR);
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
case UPGRADED:
try {
modelId = getModelId(getIngestionPipeline(PIPELINE_RESCORE_NAME), TEXT_EMBEDDING_PROCESSOR);
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_RESCORE_NAME, modelId, SEARCH_PIPELINE_RESCORE_NAME);
}
break;
default:
throw new IllegalStateException("Unexpected value: " + getClusterType());
}
}

private void validateTestIndexOnUpgrade(
final int numberOfDocs,
final String modelId,
HybridQueryBuilder hybridQueryBuilder,
QueryBuilder rescorer
) throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
null,
rescorer,
1,
Map.of("search_pipeline", SEARCH_PIPELINE_NAME)
);
Expand All @@ -113,18 +183,18 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
final RescoreContext rescoreContextForNeuralQuery
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.fieldName(VECTOR_EMBEDDING_FIELD);
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
if (Objects.nonNull(rescoreContextForNeuralQuery)) {
neuralQueryBuilder.rescoreContext(rescoreContextForNeuralQuery);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
Expand Down
Loading

0 comments on commit c7ec5a4

Please sign in to comment.