Skip to content

Commit

Permalink
Fix failed tests and rebase on main branch
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Wang <[email protected]>
  • Loading branch information
weiwang118 committed Dec 19, 2024
1 parent a501c1a commit 105c39a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
Expand Down
8 changes: 1 addition & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,6 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
final BitSet filterBitSet = getFilteredDocsBitSet(context);
final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
/*
* If filters match all docs in this segment, then there is no need to do any extra step
* and should directly do ANN Search*/
if (cardinality == maxDoc){
return doANNSearch(context, filterBitSet, cardinality, k);
}
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
Expand All @@ -156,7 +150,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
* If filters match all docs in this segment, then there is no need to do any extra step
* and should directly do ANN Search*/
if (filterWeight != null && cardinality == maxDoc) {
return doANNSearch(context, new FixedBitSet(0), 0, k);
return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k));
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
Expand Down
93 changes: 92 additions & 1 deletion src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.maxDoc()).thenReturn(filterDocIds.length + 1);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down Expand Up @@ -758,6 +758,97 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
// Given
int k = 3;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length);
for (int docId : filterDocIds) {
filterBitSet.set(docId);
}

jniServiceMockedStatic.when(
() -> JNIService.queryIndex(
anyLong(),
eq(QUERY_VECTOR),
eq(k),
eq(HNSW_METHOD_PARAMETERS),
any(),
eq(new FixedBitSet(0).getBits()),
anyInt(),
any()
)
).thenReturn(getFilteredKNNQueryResults());

final Bits liveDocsBits = mock(Bits.class);
for (int filterDocId : filterDocIds) {
when(liveDocsBits.get(filterDocId)).thenReturn(true);
}
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(QUERY_VECTOR)
.k(k)
.indexName(INDEX_NAME)
.filterQuery(FILTER_QUERY)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// Just to make sure that we are not hitting the exact search condition
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.L2.getValue()
);

when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);

// When
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);

// Then
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

jniServiceMockedStatic.verify(
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()),
times(1)
);

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

private SegmentReader mockSegmentReader() {
Path path = mock(Path.class);

Expand Down

0 comments on commit 105c39a

Please sign in to comment.