Skip to content

Commit

Permalink
Move rewrite logic for filter to doRewrite method
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 24, 2024
1 parent b670b11 commit c53fccf
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
Expand Down Expand Up @@ -485,7 +486,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
}

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
protected Query doToQuery(QueryShardContext context) {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
Expand Down Expand Up @@ -600,11 +601,6 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine));
}

// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
if (Objects.nonNull(filter)) {
filter = filter.rewrite(context);
}

String indexName = context.index().getName();

if (k != 0) {
Expand Down Expand Up @@ -715,4 +711,13 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
if (Objects.nonNull(filter)) {
filter = filter.rewrite(queryShardContext);
}
return super.doRewrite(queryShardContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
Expand Down Expand Up @@ -1306,4 +1307,19 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws
Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension"));
}

@SneakyThrows
public void testDoRewrite_whenNoFilter_thenSuccessful() {
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);
QueryBuilder rewritten = knnQueryBuilder.doRewrite(mock(QueryRewriteContext.class));
assertEquals(knnQueryBuilder, rewritten);
}

@SneakyThrows
public void testDoRewrite_whenFilterSet_thenSuccessful() {
QueryBuilder filter = QueryBuilders.termQuery("some_field", "some_value");
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, filter);
QueryBuilder rewritten = knnQueryBuilder.doRewrite(mock(QueryRewriteContext.class));
assertEquals(knnQueryBuilder, rewritten);
}
}

0 comments on commit c53fccf

Please sign in to comment.