Skip to content

Commit

Permalink
multiply factor for metadata filter
Browse files (browse the repository at this point in the history)
  • Loading branch information
sdan committed Apr 26, 2024
1 parent 35f4bc7 commit 584e6c7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions vlite/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from typing import List, Union, Dict

# not implemented
class BinaryVectorIndex:
def __init__(self, embedding_size=64):
self.index = {}
Expand Down
13 changes: 10 additions & 3 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):

def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
start_time = time.time()

# If metadata filter is provided, retrieve more items initially
if metadata:
initial_top_k = top_k * 4 # Adjust this factor as needed
else:
initial_top_k = top_k

logger.debug(f"[VLite.rank_and_filter] Shape of query vector: {query_binary_vector.shape}")
query_binary_vector = np.array(query_binary_vector).reshape(-1)
logger.debug(f"[VLite.rank_and_filter] Shape of query vector after reshaping: {query_binary_vector.shape}")
Expand All @@ -133,9 +140,9 @@ def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
logger.debug(f"[VLite.rank_and_filter] Shape of corpus binary vectors array: {corpus_binary_vectors.shape}")
else:
raise ValueError("No valid binary vectors found for comparison.")
top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
logger.debug(f"[VLite.rank_and_filter] Top {top_k} indices: {top_k_indices}")
logger.debug(f"[VLite.rank_and_filter] Top {top_k} scores: {top_k_scores}")
top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, initial_top_k)
logger.debug(f"[VLite.rank_and_filter] Top {initial_top_k} indices: {top_k_indices}")
logger.debug(f"[VLite.rank_and_filter] Top {initial_top_k} scores: {top_k_scores}")
logger.debug(f"[VLite.rank_and_filter] No. of items in the collection: {len(self.index)}")
logger.debug(f"[VLite.rank_and_filter] Vlite count: {self.count()}")

Expand Down

0 comments on commit 584e6c7

Please sign in to comment.