Skip to content

Commit

Permalink
Refactor and update unit test to include a field with no live docs (#2167)
Browse files Browse the repository at this point in the history
Refactored if/else to reduce nesting.
Added a unit test for the case where one of the fields doesn't have live docs.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB authored Sep 30, 2024
1 parent 6f6dd56 commit f16f225
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Maintenance
* Remove benchmarks folder from k-NN repo [#2127](https://github.com/opensearch-project/k-NN/pull/2127)
### Refactoring
* Minor refactoring and refactored some unit tests [#2167](https://github.com/opensearch-project/k-NN/pull/2167)
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,24 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
final FieldInfo fieldInfo = field.getFieldInfo();
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = field.getVectors().size();
if (totalLiveDocs > 0) {
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();
writer.flushIndex(knnVectorValues, totalLiveDocs);
long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
} else {
if (totalLiveDocs == 0) {
log.debug("[Flush] No live docs for field {}", fieldInfo.getName());
continue;
}
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();
writer.flushIndex(knnVectorValues, totalLiveDocs);
long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.carrotsearch.randomizedtesting.RandomizedTest.$;
Expand All @@ -44,6 +47,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -86,6 +90,7 @@ public static Collection<Object[]> data() {
"Multi Field",
List.of(
Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }),
Collections.emptyMap(),
Map.of(
0,
new float[] { 1, 2, 3, 4 },
Expand All @@ -105,18 +110,16 @@ public static Collection<Object[]> data() {
@SneakyThrows
public void testFlush() {
// Given
List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final List<KNNVectorValues<float[]>> expectedVectorValues = vectorsPerField.stream().map(vectors -> {
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(vectorsPerField.get(i).values())
new ArrayList<>(vectors.values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
VectorDataType.FLOAT,
randomVectorValues
);
expectedVectorValues.add(knnVectorValues);

});
return knnVectorValues;
}).collect(Collectors.toList());

try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
Expand Down Expand Up @@ -172,15 +175,19 @@ public void testFlush() {

IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
if (vectorsPerField.get(i).isEmpty()) {
verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
} else {
verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});

final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size())
times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled))
);
}
}
Expand Down Expand Up @@ -264,16 +271,21 @@ public void testFlush_WithQuantization() {

IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
if (vectorsPerField.get(i).isEmpty()) {
verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState);
verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
} else {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});

final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(expectedVectorValues.size() * 2)
times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2)
);
}
}
Expand Down

0 comments on commit f16f225

Please sign in to comment.