Skip to content

Commit

Permalink
Skipping results of filter if field does not exists
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 24, 2024
1 parent 71fff47 commit 1899460
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
}

@Override
protected Query doToQuery(QueryShardContext context) {
protected Query doToQuery(QueryShardContext context) throws IOException {
MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);

if (mappedFieldType == null && ignoreUnmapped) {
Expand Down Expand Up @@ -600,6 +600,11 @@ protected Query doToQuery(QueryShardContext context) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine));
}

// rewrite filter query if it exists
if (Objects.nonNull(filter)) {
filter = filter.rewrite(context);
}

String indexName = context.index().getName();

if (k != 0) {
Expand Down
72 changes: 72 additions & 0 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
Expand Down Expand Up @@ -278,6 +279,77 @@ public void testQueryWithFilterUsingByteVectorDataType() {
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
}

@SneakyThrows
public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, DIMENSION)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName())
.endObject()
.endObject()
.startObject("int_field")
.field(TYPE_FIELD_NAME, "integer")
.endObject()
.endObject()
.endObject();
Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(INDEX_NAME, mapping);

Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };

String documentAsString = XContentFactory.jsonBuilder()
.startObject()
.field("int_field", 5)
.field(FIELD_NAME, vector)
.endObject()
.toString();

addKnnDoc(INDEX_NAME, DOC_ID, documentAsString);

refreshIndex(INDEX_NAME);
assertEquals(1, getDocCount(INDEX_NAME));

float[] searchVector = new float[] { 1.0f, 2.1f, 3.9f };
int k = 10;

// use filter where non existent field is must, we should have no results
QueryBuilder filterWithRequiredNonExistentField = QueryBuilders.boolQuery()
.must(QueryBuilders.rangeQuery("nonexistent_int_field").gte(1));
Response searchWithRequiredNonExistentFiledInFilterResponse = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithRequiredNonExistentField),
k
);
List<KNNResult> resultsQuery1 = parseSearchResponse(
EntityUtils.toString(searchWithRequiredNonExistentFiledInFilterResponse.getEntity()),
FIELD_NAME
);
assertTrue(resultsQuery1.isEmpty());

// use filter with non existent field as optional, we should have some results
QueryBuilder filterWithOptionalNonExistentField = QueryBuilders.boolQuery()
.should(QueryBuilders.rangeQuery("nonexistent_int_field").gte(1))
.must(QueryBuilders.rangeQuery("int_field").gte(1));
Response searchWithOptionalNonExistentFiledInFilterResponse = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithOptionalNonExistentField),
k
);
List<KNNResult> resultsQuery2 = parseSearchResponse(
EntityUtils.toString(searchWithOptionalNonExistentFiledInFilterResponse.getEntity()),
FIELD_NAME
);
assertEquals(1, resultsQuery2.size());
}

public void testQuery_filterWithNonLuceneEngine() throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
Expand Down Expand Up @@ -485,6 +486,7 @@ public void testDoToQuery_Normal() throws Exception {
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
}

@SneakyThrows
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
Expand Down Expand Up @@ -518,6 +520,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th
);
}

@SneakyThrows
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

Expand All @@ -540,6 +543,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS
assertTrue(query.toString().contains("resultSimilarity=" + 0.5f));
}

@SneakyThrows
public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
float negativeDistance = -1.0f;
Expand Down Expand Up @@ -602,6 +606,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

@SneakyThrows
public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
float score = 5f;
Expand Down Expand Up @@ -655,6 +660,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

@SneakyThrows
public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
float negativeDistance = -1.0f;
Expand Down Expand Up @@ -774,6 +780,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}

@SneakyThrows
public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

Expand Down Expand Up @@ -802,6 +809,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th
assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class));
}

@SneakyThrows
public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
Expand All @@ -828,6 +836,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS
assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class));
}

@SneakyThrows
public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {
// Given
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down Expand Up @@ -904,6 +913,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException()
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

@SneakyThrows
public void testDoToQuery_FromModel() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
Expand Down Expand Up @@ -938,6 +948,7 @@ public void testDoToQuery_FromModel() {
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
}

@SneakyThrows
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

Expand Down Expand Up @@ -979,6 +990,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
}

@SneakyThrows
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

Expand Down Expand Up @@ -1233,6 +1245,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

@SneakyThrows
public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() {
KNNMethodContext knnMethodContext = new KNNMethodContext(
KNNEngine.FAISS,
Expand Down

0 comments on commit 1899460

Please sign in to comment.