Skip to content

Commit

Permalink
Add b1 support back, but with ubinary
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Oct 15, 2024
1 parent 95a0acb commit 5ac0354
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# 4. Choose a target precision for the corpus embeddings
corpus_precision = "binary"
# Valid options are: "float32", "uint8", "int8", "ubinary", and "binary"
# But usearch only supports "float32", "int8", and "binary"
# But usearch only supports "float32", "int8", "binary", and "ubinary"

# 5. Encode the corpus
full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True)
Expand Down
16 changes: 11 additions & 5 deletions sentence_transformers/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def semantic_search_usearch(
`corpus_embeddings` or `corpus_index` should be used, not
both.
corpus_precision: Precision of the corpus embeddings. The
options are "float32", "int8", or "binary". Default is
"float32".
options are "float32", "int8", "ubinary" or "binary". Default
is "float32".
top_k: Number of top results to retrieve. Default is 10.
ranges: Ranges for quantization of embeddings. This is only used
for int8 quantization, where the ranges refer to the
Expand Down Expand Up @@ -263,8 +263,8 @@ def semantic_search_usearch(
raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
if corpus_embeddings is None and corpus_index is None:
raise ValueError("Either corpus_embeddings or corpus_index should be used.")
if corpus_precision not in ["float32", "int8", "binary"]:
raise ValueError('corpus_precision must be "float32", "int8", or "binary" for usearch')
if corpus_precision not in ["float32", "int8", "ubinary", "binary"]:
raise ValueError('corpus_precision must be "float32", "int8", "ubinary", "binary" for usearch')

# If corpus_index is not provided, create a new index
if corpus_index is None:
Expand All @@ -286,6 +286,12 @@ def semantic_search_usearch(
metric="hamming",
dtype="i8",
)
elif corpus_precision == "ubinary":
corpus_index = Index(
ndim=corpus_embeddings.shape[1] * 8,
metric="hamming",
dtype="b1",
)
corpus_index.add(np.arange(len(corpus_embeddings)), corpus_embeddings)

# If rescoring is enabled and the query embeddings are in float32, we need to quantize them
Expand Down Expand Up @@ -331,7 +337,7 @@ def semantic_search_usearch(
if rescore_embeddings is not None:
top_k_embeddings = np.array([corpus_index.get(query_indices) for query_indices in indices])
# If the corpus precision is binary or ubinary, we need to unpack the bits
if corpus_precision == "binary":
if corpus_precision in ("ubinary", "binary"):
top_k_embeddings = np.unpackbits(top_k_embeddings.astype(np.uint8), axis=-1)
top_k_embeddings = top_k_embeddings.astype(int)

Expand Down

0 comments on commit 5ac0354

Please sign in to comment.