From 7f05a2b20a251c7a701866d0c68462675dfbf908 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 23 Jul 2024 16:35:05 -0700 Subject: [PATCH] Return empty results for non-existent filter fields Signed-off-by: Martin Gaievski --- .../knn/index/query/KNNQueryBuilder.java | 7 +- .../opensearch/knn/index/LuceneEngineIT.java | 72 +++++++++++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 13 ++++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 664a4de3e..0dc27f04a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -485,7 +485,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio } @Override - protected Query doToQuery(QueryShardContext context) { + protected Query doToQuery(QueryShardContext context) throws IOException { MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName); if (mappedFieldType == null && ignoreUnmapped) { @@ -600,6 +600,11 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine)); } + // rewrite filter query if it exists to avoid runtime errors in next steps of query phase + if (Objects.nonNull(filter)) { + filter = filter.rewrite(context); + } + String indexName = context.index().getName(); if (k != 0) { diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index e7f38787d..52db96e2e 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -17,6 +17,7 @@ import org.opensearch.common.Nullable; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; @@ -278,6 +279,77 @@ public void testQueryWithFilterUsingByteVectorDataType() { validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); } + @SneakyThrows + public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName()) + .endObject() + .endObject() + .startObject("int_field") + .field(TYPE_FIELD_NAME, "integer") + .endObject() + .endObject() + .endObject(); + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + + Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f }; + + String documentAsString = XContentFactory.jsonBuilder() + .startObject() + .field("int_field", 5) + .field(FIELD_NAME, vector) + .endObject() + .toString(); + + addKnnDoc(INDEX_NAME, DOC_ID, documentAsString); + + refreshIndex(INDEX_NAME); + assertEquals(1, getDocCount(INDEX_NAME)); + + float[] searchVector = new float[] { 1.0f, 2.1f, 3.9f }; + int k = 10; + + // use filter where non existent field is must, we should have no results + QueryBuilder filterWithRequiredNonExistentField = QueryBuilders.boolQuery() + .must(QueryBuilders.rangeQuery("nonexistent_int_field").gte(1)); + Response searchWithRequiredNonExistentFiledInFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithRequiredNonExistentField), + k + ); + List resultsQuery1 = parseSearchResponse( + EntityUtils.toString(searchWithRequiredNonExistentFiledInFilterResponse.getEntity()), + FIELD_NAME + ); + assertTrue(resultsQuery1.isEmpty()); + + // use filter with non existent field as optional, we should have some results + QueryBuilder filterWithOptionalNonExistentField = QueryBuilders.boolQuery() + .should(QueryBuilders.rangeQuery("nonexistent_int_field").gte(1)) + .must(QueryBuilders.rangeQuery("int_field").gte(1)); + Response searchWithOptionalNonExistentFiledInFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithOptionalNonExistentField), + k + ); + List resultsQuery2 = parseSearchResponse( + EntityUtils.toString(searchWithOptionalNonExistentFiledInFilterResponse.getEntity()), + FIELD_NAME + ); + assertEquals(1, resultsQuery2.size()); + } + public void testQuery_filterWithNonLuceneEngine() throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index b28739655..070ea0fbb 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; import org.apache.lucene.search.FloatVectorSimilarityQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; @@ -485,6 +486,7 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -518,6 +520,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th ); } + @SneakyThrows public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -540,6 +543,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; @@ -602,6 +606,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; @@ -655,6 +660,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; @@ -774,6 +780,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -802,6 +809,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + @SneakyThrows public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -828,6 +836,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + @SneakyThrows public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -904,6 +913,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testDoToQuery_FromModel() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -938,6 +948,7 @@ public void testDoToQuery_FromModel() { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -979,6 +990,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + @SneakyThrows public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -1233,6 +1245,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + @SneakyThrows public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.FAISS,