Skip to content

Commit

Permalink
Fix runtime exceptions in hybrid query for case when sub-query scorer…
Browse files Browse the repository at this point in the history
… return TwoPhase iterator that is incompatible with DISI iterator (#624)

* Adding two phase iterator

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Mar 11, 2024
1 parent ea49d3c commit c9cdcc1
Show file tree
Hide file tree
Showing 16 changed files with 1,573 additions and 254 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand All @@ -54,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private String fieldName;

private static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
super(in);
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
228 changes: 225 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -18,10 +19,13 @@
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +44,56 @@ public final class HybridQueryScorer extends Scorer {

private final Map<Query, List<Integer>> queryToIndex;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
private final DocIdSetIterator approximation;
private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
this.subScorers = Collections.unmodifiableList(subScorers);
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
this.subScorersPQ = initializeSubScorersPQ();
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ);
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;
}

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
}
}

@Override
public int advanceShallow(int target) throws IOException {
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);
}
return super.advanceShallow(target);
}

/**
Expand All @@ -55,7 +103,10 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOExcept
*/
@Override
public float score() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -67,13 +118,30 @@ public float score() throws IOException {
return totalScore;
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();
}
}

/**
* Return a DocIdSetIterator over matching documents.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {
return approximation;
}
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +161,28 @@ public float getMaxScore(int upTo) throws IOException {
}).max(Float::compare).orElse(0.0f);
}

@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);
}

for (Scorer scorer : subScorers) {
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);
}
}
}

/**
* Returns the doc ID that is currently being scored.
* @return document id
*/
@Override
public int docID() {
if (subScorersPQ.size() == 0) {
return DocIdSetIterator.NO_MORE_DOCS;
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +253,142 @@ private DisiPriorityQueue initializeSubScorersPQ() {
}
return subScorersPQ;
}

@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));
}
return children;
}

/**
* Object returned by {@link Scorer#twoPhaseIterator()} to provide an approximation of a {@link DocIdSetIterator}.
* After calling {@link DocIdSetIterator#nextDoc()} or {@link DocIdSetIterator#advance(int)} on the iterator
* returned by approximation(), you need to check {@link TwoPhaseIterator#matches()} to confirm if the retrieved
* document ID is a match. Implementation inspired by identical class for
* <a href="https://github.com/apache/lucene/blob/branch_9_10/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java">DisjunctionScorer</a>
*/
static class TwoPhase extends TwoPhaseIterator {
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

DisiWrapper getSubMatches() throws IOException {
for (DisiWrapper wrapper : unverifiedMatches) {
if (wrapper.twoPhaseView.matches()) {
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;
}
}
unverifiedMatches.clear();
return verifiedMatches;
}

@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) {
DisiWrapper next = wrapper.next;

if (Objects.isNull(wrapper.twoPhaseView)) {
// implicitly verified, move it to verifiedMatches
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;

if (!needsScores) {
// we can stop here
return true;
}
} else {
unverifiedMatches.add(wrapper);
}
wrapper = next;
}

if (Objects.nonNull(verifiedMatches)) {
return true;
}

// verify subs that have an two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper wrapper = unverifiedMatches.pop();
if (wrapper.twoPhaseView.matches()) {
wrapper.next = null;
verifiedMatches = wrapper;
return true;
}
}
return false;
}

@Override
public float matchCost() {
return matchCost;
}
}

/**
* A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports
* sub iterators that return empty results
*/
static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator docIdSetIterator;
final DisiPriorityQueue subIterators;

public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

@Override
public long cost() {
return docIdSetIterator.cost();
}

@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return docIdSetIterator.docID();
}

@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return docIdSetIterator.nextDoc();
}

@Override
public int advance(final int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return docIdSetIterator.advance(target);
}
}
}
Loading

0 comments on commit c9cdcc1

Please sign in to comment.