From 69e7569e1c94cc035520b40befdbddd19d2791e3 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 30 Aug 2023 17:40:05 -0700 Subject: [PATCH] boost added --- backend/danswer/chunking/models.py | 3 +++ backend/danswer/search/semantic_search.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/danswer/chunking/models.py b/backend/danswer/chunking/models.py index 0ead152b779..f0b8d8ac729 100644 --- a/backend/danswer/chunking/models.py +++ b/backend/danswer/chunking/models.py @@ -5,6 +5,7 @@ from typing import cast from danswer.configs.constants import BLURB +from danswer.configs.constants import BOOST from danswer.configs.constants import METADATA from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS @@ -57,6 +58,7 @@ class InferenceChunk(BaseChunk): document_id: str source_type: str semantic_identifier: str + boost: float metadata: dict[str, Any] @classmethod @@ -78,6 +80,7 @@ def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk": init_kwargs[METADATA] = json.loads(init_kwargs[METADATA]) else: init_kwargs[METADATA] = {} + init_kwargs[BOOST] = init_kwargs.get(BOOST, 1) if init_kwargs.get(SEMANTIC_IDENTIFIER) is None: logger.error( f"Chunk with blurb: {init_kwargs.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier" diff --git a/backend/danswer/search/semantic_search.py b/backend/danswer/search/semantic_search.py index c6122b9b473..86096495c12 100644 --- a/backend/danswer/search/semantic_search.py +++ b/backend/danswer/search/semantic_search.py @@ -57,7 +57,8 @@ def semantic_reranking( encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore for encoder in cross_encoders ] - averaged_sim_scores = sum(sim_scores) / len(sim_scores) + boosts = [chunk.boost for chunk in chunks] + averaged_sim_scores = sum(sim_scores) * boosts / len(sim_scores) scored_results = list(zip(averaged_sim_scores, chunks)) scored_results.sort(key=lambda x: x[0], reverse=True) ranked_sim_scores, ranked_chunks = zip(*scored_results)