Learn from Feedback #343

Merged 14 commits on Aug 28, 2023
93 changes: 93 additions & 0 deletions backend/alembic/versions/d929f0c1c6af_feedback_feature.py
@@ -0,0 +1,93 @@
"""Feedback Feature

Revision ID: d929f0c1c6af
Revises: 8aabb57f3b49
Create Date: 2023-08-27 13:03:54.274987

"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d929f0c1c6af"
down_revision = "8aabb57f3b49"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"query_event",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("query", sa.String(), nullable=False),
sa.Column(
"selected_search_flow",
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
nullable=True,
),
sa.Column("llm_answer", sa.String(), nullable=True),
sa.Column(
"feedback",
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
nullable=True,
),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"document_retrieval_feedback",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("qa_event_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.Column("document_rank", sa.Integer(), nullable=False),
sa.Column("clicked", sa.Boolean(), nullable=False),
sa.Column(
"feedback",
sa.Enum(
"ENDORSE",
"REJECT",
"HIDE",
"UNHIDE",
name="searchfeedbacktype",
),
nullable=True,
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
),
sa.ForeignKeyConstraint(
["qa_event_id"],
["query_event.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("document", sa.Column("boost", sa.Integer(), nullable=False))
op.add_column("document", sa.Column("hidden", sa.Boolean(), nullable=False))
op.add_column("document", sa.Column("semantic_id", sa.String(), nullable=False))
op.add_column("document", sa.Column("link", sa.String(), nullable=True))


def downgrade() -> None:
op.drop_column("document", "link")
op.drop_column("document", "semantic_id")
op.drop_column("document", "hidden")
op.drop_column("document", "boost")
op.drop_table("document_retrieval_feedback")
op.drop_table("query_event")
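
For orientation, the new query_event table maps naturally onto a SQLAlchemy declarative model (2.0-style Mapped annotations assumed here). The sketch below is illustrative only: the PR's real declarations presumably live in backend/danswer/db/models.py, which is not part of this excerpt, and the Base class name is an assumption.

# Illustrative sketch only -- not the PR's actual model code.
from datetime import datetime
from uuid import UUID

from fastapi_users_db_sqlalchemy.generics import GUID
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from danswer.configs.constants import QAFeedbackType  # added further down in this PR


class Base(DeclarativeBase):  # assumption: the project declares its own base class
    pass


class QueryEvent(Base):
    __tablename__ = "query_event"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    query: Mapped[str] = mapped_column(String, nullable=False)
    # selected_search_flow omitted for brevity
    llm_answer: Mapped[str | None] = mapped_column(String, nullable=True)
    # sa.Enum(QAFeedbackType) persists member names ("LIKE"/"DISLIKE"), matching the migration
    feedback: Mapped[QAFeedbackType | None] = mapped_column(
        Enum(QAFeedbackType, name="qafeedbacktype"), nullable=True
    )
    user_id: Mapped[UUID | None] = mapped_column(
        GUID(), ForeignKey("user.id"), nullable=True
    )
    time_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )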
2 changes: 1 addition & 1 deletion backend/danswer/background/connector_deletion.py
@@ -90,7 +90,7 @@ def _update_multi_indexed_docs() -> None:
def _get_user(
credential: Credential,
) -> str:
if credential.public_doc:
if credential.public_doc or not credential.user:
return PUBLIC_DOC_PAT

return str(credential.user.id)
2 changes: 2 additions & 0 deletions backend/danswer/background/update.py
@@ -23,6 +23,7 @@
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_document_metadata
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -246,6 +247,7 @@ def _index(
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)

index_user_id = (
None if db_credential.public_doc else db_credential.user_id
)
13 changes: 13 additions & 0 deletions backend/danswer/configs/constants.py
@@ -18,6 +18,7 @@
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
BOOST = "boost"
DEFAULT_BOOST = 0


class DocumentSource(str, Enum):
@@ -66,3 +67,15 @@ class ModelHostType(str, Enum):
# https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
COLAB_DEMO = "colab-demo"
# TODO support for Azure, AWS, GCP GenAI model hosting


class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics


class SearchFeedbackType(str, Enum):
ENDORSE = "endorse" # boost this document for all future queries
REJECT = "reject" # down-boost this document for all future queries
HIDE = "hide" # mark this document as untrusted, hide from LLM
UNHIDE = "unhide"
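
The comments above describe the intended effect of each feedback type. Below is a minimal sketch of that mapping, purely for illustration: the function name and signature are made up, and the PR's real handler is not shown in this excerpt.

def apply_search_feedback(
    boost: int, hidden: bool, feedback: SearchFeedbackType
) -> tuple[int, bool]:
    # Hypothetical helper: return the updated (boost, hidden) pair for a document.
    if feedback == SearchFeedbackType.ENDORSE:
        return boost + 1, hidden
    if feedback == SearchFeedbackType.REJECT:
        return boost - 1, hidden
    if feedback == SearchFeedbackType.HIDE:
        return boost, True
    return boost, False  # UNHIDE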
9 changes: 9 additions & 0 deletions backend/danswer/datastores/datastore_utils.py
@@ -12,6 +12,15 @@


DEFAULT_BATCH_SIZE = 30
BOOST_MULTIPLIER = 1.2


def translate_boost_count_to_multiplier(boost: int) -> float:
if boost > 0:
return BOOST_MULTIPLIER**boost
elif boost < 0:
# a negative boost should down-weight, so use abs() to keep the multiplier below 1
return 1 / (BOOST_MULTIPLIER ** abs(boost))
return 1
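
As a quick sanity check on the boost math (with the negative branch using abs() so a negative count lands below 1), a sketch assuming the module path shown in this diff:

from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier

assert translate_boost_count_to_multiplier(0) == 1
assert round(translate_boost_count_to_multiplier(2), 4) == 1.44     # boosted docs score higher
assert round(translate_boost_count_to_multiplier(-2), 4) == 0.6944  # down-boosted docs score lower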


def get_uuid_from_chunk(
18 changes: 16 additions & 2 deletions backend/danswer/datastores/indexing_pipeline.py
@@ -32,6 +32,7 @@ def __call__(
def _upsert_insertion_records(
insertion_records: set[DocumentInsertionRecord],
index_attempt_metadata: IndexAttemptMetadata,
doc_m_data_lookup: dict[str, tuple[str, str]],
) -> None:
with Session(get_sqlalchemy_engine()) as session:
upsert_documents_complete(
@@ -40,9 +41,11 @@ def _upsert_insertion_records(
DocumentMetadata(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_id=insertion_record.document_id,
document_id=i_r.document_id,
semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
first_link=doc_m_data_lookup[i_r.document_id][1],
)
for insertion_record in insertion_records
for i_r in insertion_records
],
)

@@ -62,6 +65,11 @@ def _get_net_new_documents(
return net_new_documents


def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
first_link = next((section.link for section in doc.sections if section.link), "")
return doc.semantic_identifier, first_link


def _indexing_pipeline(
*,
chunker: Chunker,
@@ -73,6 +81,11 @@
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""

document_metadata_lookup = {
doc.id: _extract_minimal_document_metadata(doc) for doc in documents
}

chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in documents])
)
@@ -92,6 +105,7 @@
_upsert_insertion_records(
insertion_records=insertion_records,
index_attempt_metadata=index_attempt_metadata,
doc_m_data_lookup=document_metadata_lookup,
)
except Exception as e:
logger.error(
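
To make the new metadata plumbing concrete: _indexing_pipeline builds a lookup from document id to a (semantic_identifier, first_link) pair, and _upsert_insertion_records uses it to populate the new Document columns. An illustration with made-up values:

# Illustrative only: shape of document_metadata_lookup for a two-document batch.
document_metadata_lookup: dict[str, tuple[str, str]] = {
    "confluence__12345": ("Quarterly Planning Notes", "https://example.com/wiki/12345"),
    "slack__C024BE91L": ("#eng-search thread", ""),  # no section carried a link, falls back to ""
}

semantic_identifier, first_link = document_metadata_lookup["confluence__12345"]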
4 changes: 3 additions & 1 deletion backend/danswer/datastores/interfaces.py
@@ -22,6 +22,8 @@ class DocumentMetadata:
connector_id: int
credential_id: int
document_id: str
semantic_identifier: str
first_link: str


@dataclass
@@ -32,7 +34,7 @@ class UpdateRequest:
document_ids: list[str]
# all other fields will be left alone
allowed_users: list[str] | None = None
boost: int | None = None
boost: float | None = None


class Verifiable(abc.ABC):
12 changes: 8 additions & 4 deletions backend/danswer/datastores/vespa/store.py
@@ -342,16 +342,20 @@ def update(self, update_requests: list[UpdateRequest]) -> None:
logger.error("Update request received but nothing to update")
continue

update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost:
update_dict["fields"][BOOST] = update_request.boost
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.allowed_users:
update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users
update_dict["fields"][ALLOWED_USERS] = {
"assign": update_request.allowed_users
}

for document_id in update_request.document_ids:
for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
res = requests.put(url, headers=json_header, json=update_dict)
res = requests.put(
url, headers=json_header, data=json.dumps(update_dict)
)

try:
res.raise_for_status()
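
The switch to the assign form matches Vespa's partial-update document API, where each field carries an explicit operation; serializing with json.dumps assumes a json import at the top of the module. A hypothetical request body produced by the code above, with made-up values:

# Hypothetical partial-update body PUT to {DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}.
import json

update_dict = {
    "fields": {
        "boost": {"assign": 1.44},                # e.g. translate_boost_count_to_multiplier(2)
        "allowed_users": {"assign": ["PUBLIC"]},
    }
}
print(json.dumps(update_dict, indent=2))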
1 change: 0 additions & 1 deletion backend/danswer/db/connector_credential_pair.py
@@ -8,7 +8,6 @@
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.server.models import StatusResponse
50 changes: 32 additions & 18 deletions backend/danswer/db/document.py
@@ -7,8 +7,9 @@
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.models import Document
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
from danswer.utils.logger import setup_logger
@@ -20,7 +21,7 @@ def get_documents_with_single_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
) -> Sequence[Document]:
) -> Sequence[DbDocument]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
@@ -31,17 +32,17 @@
# Filter it down to the documents with only a single connector/credential pair
# Meaning if this connector/credential pair is removed, this doc should be gone
trimmed_doc_ids_stmt = (
select(Document.id)
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(Document.id.in_(initial_doc_ids_stmt))
.group_by(Document.id)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) == 1)
)

stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
return db_session.scalars(stmt).all()


@@ -60,13 +61,13 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
# Filter it down to the documents with more than 1 connector/credential pair
# Meaning if this connector/credential pair is removed, this doc is still accessible
trimmed_doc_ids_stmt = (
select(Document.id)
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(Document.id.in_(initial_doc_ids_stmt))
.group_by(Document.id)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) > 1)
)

@@ -81,13 +82,25 @@ def upsert_documents(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
seen_document_ids: set[str] = set()
seen_documents: dict[str, DocumentMetadata] = {}
for document_metadata in document_metadata_batch:
if document_metadata.document_id not in seen_document_ids:
seen_document_ids.add(document_metadata.document_id)
doc_id = document_metadata.document_id
if doc_id not in seen_documents:
seen_documents[doc_id] = document_metadata

insert_stmt = insert(Document).values(
[model_to_dict(Document(id=doc_id)) for doc_id in seen_document_ids]
insert_stmt = insert(DbDocument).values(
[
model_to_dict(
DbDocument(
id=doc.document_id,
boost=DEFAULT_BOOST,
hidden=False,
semantic_id=doc.semantic_identifier,
link=doc.first_link,
)
)
for doc in seen_documents.values()
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
@@ -120,7 +133,8 @@ def upsert_document_by_connector_credential_pair(


def upsert_documents_complete(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
@@ -140,7 +154,7 @@ def delete_document_by_connector_credential_pair(


def delete_documents(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(Document).where(Document.id.in_(document_ids)))
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))


def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
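
The comment in upsert_documents notes that the ON CONFLICT clause currently does nothing and would need to become an on_conflict_do_update once conflicting rows should have their metadata refreshed. A minimal sketch of that future shape, not part of this PR, using the same Postgres insert construct and made-up values:

from sqlalchemy.dialects.postgresql import insert

# Sketch only: refresh display metadata when a document id already exists.
insert_stmt = insert(DbDocument).values(
    [
        {
            "id": "confluence__12345",
            "boost": DEFAULT_BOOST,
            "hidden": False,
            "semantic_id": "Quarterly Planning Notes",
            "link": "https://example.com/wiki/12345",
        }
    ]
)
upsert_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["id"],  # conflict on the primary key
    set_={
        "semantic_id": insert_stmt.excluded.semantic_id,
        "link": insert_stmt.excluded.link,
    },
)
# db_session.execute(upsert_stmt) would then insert-or-update in one statement.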