From 9c855ddf8c10cc37e6580e08e625fe72c6ad8a2e Mon Sep 17 00:00:00 2001 From: Wei Wang <93847013+weiwang118@users.noreply.github.com> Date: Sat, 4 Jan 2025 03:55:18 +0800 Subject: [PATCH] Add check to directly use ANN Search when filters match all docs. (#2320) * Add check to directly use ANN Search when filters match all docs. Signed-off-by: Wei Wang * Fix failed tests and rebase on main branch Signed-off-by: Wei Wang * pass filterbitset as null and add integ tests. Signed-off-by: Wei Wang --------- Signed-off-by: Wei Wang Co-authored-by: Wei Wang --- CHANGELOG.md | 1 + .../knn/index/query/FilterIdsSelector.java | 5 +- .../opensearch/knn/index/query/KNNWeight.java | 10 +- .../knn/index/query/KNNWeightTests.java | 92 ++++++++++++++++++- .../knn/integ/FilteredSearchANNSearchIT.java | 57 ++++++++++++ 5 files changed, 159 insertions(+), 6 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 992e6c52b..c148d8b3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] +- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] diff --git a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java index bf06e8c5e..12711911a 100644 --- a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java +++ b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java @@ -78,7 +78,10 @@ public enum FilterIdsSelectorType { public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException { long[] filterIds; FilterIdsSelector.FilterIdsSelectorType filterType; - if (filterIdsBitSet instanceof FixedBitSet) { + if (filterIdsBitSet == null) { + filterIds = null; + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } else if (filterIdsBitSet instanceof FixedBitSet) { /** * When filterIds is dense filter, using fixed bitset */ diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 891f9325c..37b5cc9ad 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); + final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good @@ -145,7 +146,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep Map result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k); return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); } - Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); + + /* + * If filters match all docs in this segment, then null should be passed as filterBitSet + * so that it will not do a bitset look up in bottom search layer. + */ + final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; + final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); + // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 511895026..8011cc08c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is when(liveDocsBits.length()).thenReturn(1000); final SegmentReader reader = mockSegmentReader(); - when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.maxDoc()).thenReturn(filterDocIds.length + 1); when(reader.getLiveDocs()).thenReturn(liveDocsBits); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -758,6 +758,88 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + @SneakyThrows + public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { + // Given + int k = 3; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final Bits liveDocsBits = mock(Bits.class); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); + + final SegmentReader reader = mockSegmentReader(); + when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // Just to make sure that we are not hitting the exact search condition + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + private SegmentReader mockSegmentReader() { Path path = mock(Path.class); @@ -815,7 +897,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -891,6 +973,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final FSDirectory directory = mock(FSDirectory.class); when(reader.directory()).thenReturn(directory); @@ -968,7 +1051,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -1168,6 +1251,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); @@ -1202,7 +1286,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // We will have 0, 1 for filteredIds and 2 will be the parent id for both of them final Scorer filterScorer = mock(Scorer.class); when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2)); - when(reader.maxDoc()).thenReturn(2); + when(reader.maxDoc()).thenReturn(3); // Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result final List vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f }); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java new file mode 100644 index 000000000..191ab944c --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class FilteredSearchANNSearchIT extends KNNRestTestCase { + @SneakyThrows + public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() { + String filterFieldName = "color"; + final int expectResultSize = randomIntBetween(1, 3); + final String filterValue = "red"; + createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME)); + + // ingest 4 vector docs into the index with the same field {"color": "red"} + for (int i = 0; i < 4; i++) { + addKnnDocWithAttributes(String.valueOf(i), new float[] { i, i, i }, ImmutableMap.of(filterFieldName, filterValue)); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); + + Float[] queryVector = { 3f, 3f, 3f }; + // All docs in one segment will match the filters value + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(expectResultSize) + .filterFieldName(filterFieldName) + .filterValue(filterValue) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(expectResultSize, docIds.size()); + assertEquals(expectResultSize, parseTotalSearchHits(entity)); + } +}