From e307275774bc20f5d1c0f0b65de735d513df352e Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Tue, 22 Aug 2023 18:11:31 -0700
Subject: [PATCH] Add support for multiple indexing workers (#322)

---
 backend/danswer/background/update.py          | 309 +++++++++++++-----
 backend/danswer/configs/app_configs.py        |   5 +
 .../danswer/db/connector_credential_pair.py   |   6 +-
 backend/danswer/db/index_attempt.py           |   7 +
 backend/danswer/listeners/slack_listener.py   |  15 +-
 backend/danswer/utils/logger.py               |  42 ++-
 backend/requirements/default.txt              |   2 +
 .../docker_compose/docker-compose.dev.yml     |   1 +
 8 files changed, 293 insertions(+), 94 deletions(-)

diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index c5617da1566..b0491a288b9 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -1,10 +1,16 @@
+import logging
 import time
 from datetime import datetime
 from datetime import timezone
 
+from dask.distributed import Client
+from dask.distributed import Future
+from distributed import LocalCluster
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.connectors.factory import instantiate_connector
+from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import IndexAttemptMetadata
@@ -18,6 +24,7 @@
 from danswer.db.engine import get_db_current_time
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import create_index_attempt
+from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
 from danswer.db.index_attempt import get_last_attempt
 from danswer.db.index_attempt import get_not_started_index_attempts
@@ -28,6 +35,7 @@
 from danswer.db.models import Connector
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
+from danswer.utils.logger import IndexAttemptSingleton
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -45,7 +53,7 @@ def should_create_new_indexing(
     return time_since_index.total_seconds() >= connector.refresh_freq
 
 
-def create_indexing_jobs(db_session: Session) -> None:
+def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None:
     connectors = fetch_connectors(db_session)
 
     # clean up in-progress jobs that were never completed
     for connector in connectors:
         in_progress_indexing_attempts = get_inprogress_index_attempts(
             connector.id, db_session
         )
-        if in_progress_indexing_attempts:
-            logger.error("Found incomplete indexing attempts")
-
-        # Currently single threaded so any still in-progress must have errored
         for attempt in in_progress_indexing_attempts:
+            # if a job is still going, don't touch it
+            if attempt.id in existing_jobs:
+                continue
+
             logger.warning(
                 f"Marking in-progress attempt 'connector: {attempt.connector_id}, "
                 f"credential: {attempt.credential_id}' as failed"
             )
             mark_attempt_failed(
                 attempt,
                 db_session,
                 failure_reason="Stopped mid run, likely due to the background process being killed",
             )
             if attempt.connector_id and attempt.credential_id:
                 update_connector_credential_pair(
+                    db_session=db_session,
                     connector_id=attempt.connector_id,
                     credential_id=attempt.credential_id,
                     attempt_status=IndexingStatus.FAILED,
-                    net_docs=None,
-                    run_dt=None,
-                    db_session=db_session,
                 )
 
     # potentially kick off new runs
@@ -91,41 +97,69 @@ def create_indexing_jobs(db_session: Session) -> None:
             create_index_attempt(connector.id, credential.id, db_session)
 
             update_connector_credential_pair(
+                db_session=db_session,
                 connector_id=connector.id,
                 credential_id=credential.id,
                 attempt_status=IndexingStatus.NOT_STARTED,
-                net_docs=None,
-                run_dt=None,
-                db_session=db_session,
             )
 
 
-def run_indexing_jobs(db_session: Session) -> None:
-    indexing_pipeline = build_indexing_pipeline()
+def cleanup_indexing_jobs(
+    db_session: Session, existing_jobs: dict[int, Future]
+) -> dict[int, Future]:
+    existing_jobs_copy = existing_jobs.copy()
 
-    new_indexing_attempts = get_not_started_index_attempts(db_session)
-    logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
-    for attempt in new_indexing_attempts:
-        if attempt.connector is None:
-            logger.warning(
-                f"Skipping index attempt as Connector has been deleted: {attempt}"
+    for attempt_id, job in existing_jobs.items():
+        if not job.done():
+            continue
+
+        # cleanup completed job
+        job.release()
+        del existing_jobs_copy[attempt_id]
+        index_attempt = get_index_attempt(
+            db_session=db_session, index_attempt_id=attempt_id
+        )
+        if not index_attempt:
+            logger.error(
+                f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
+                "up indexing jobs"
             )
-            mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
             continue
-        if attempt.credential is None:
+
+        if index_attempt.status == IndexingStatus.IN_PROGRESS:
             logger.warning(
-                f"Skipping index attempt as Credential has been deleted: {attempt}"
+                f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
+                f"credential: {index_attempt.credential_id}' as failed"
            )
             mark_attempt_failed(
-                attempt, db_session, failure_reason="Credential is null"
+                index_attempt=index_attempt,
+                db_session=db_session,
+                failure_reason="Stopped mid run, likely due to the background process being killed",
             )
-            continue
 
-        logger.info(
-            f"Starting new indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
-        )
+            if index_attempt.connector_id and index_attempt.credential_id:
+                update_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=index_attempt.connector_id,
+                    credential_id=index_attempt.credential_id,
+                    attempt_status=IndexingStatus.FAILED,
+                )
+
+    return existing_jobs_copy
+
+
+def _run_indexing(
+    db_session: Session,
+    index_attempt: IndexAttempt,
+) -> None:
+    """
+    1. Get documents which are either new or updated from specified application
+    2. Embed and index these documents into the chosen datastores (e.g. Qdrant / Typesense or Vespa)
+    3. Updates Postgres to record the indexed documents + the outcome of this run
+    """
+    def _get_document_generator(
+        db_session: Session, attempt: IndexAttempt
+    ) -> tuple[GenerateDocumentsOutput, float]:
         # "official" timestamp for this run
         # used for setting time bounds when fetching updates from apps and
         # is stored in the DB as the last successful run time if this run succeeds
@@ -133,67 +167,70 @@ def run_indexing_jobs(db_session: Session) -> None:
         run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
         run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
 
-        mark_attempt_in_progress(attempt, db_session)
-
-        db_connector = attempt.connector
-        db_credential = attempt.credential
-        task = db_connector.input_type
-
-        update_connector_credential_pair(
-            connector_id=db_connector.id,
-            credential_id=db_credential.id,
-            attempt_status=IndexingStatus.IN_PROGRESS,
-            net_docs=None,
-            run_dt=None,
-            db_session=db_session,
-        )
+        task = attempt.connector.input_type
 
         try:
             runnable_connector, new_credential_json = instantiate_connector(
-                db_connector.source,
+                attempt.connector.source,
                 task,
-                db_connector.connector_specific_config,
-                db_credential.credential_json,
+                attempt.connector.connector_specific_config,
+                attempt.credential.credential_json,
             )
             if new_credential_json is not None:
                 backend_update_credential_json(
-                    db_credential, new_credential_json, db_session
+                    attempt.credential, new_credential_json, db_session
                 )
         except Exception as e:
             logger.exception(f"Unable to instantiate connector due to {e}")
-            disable_connector(db_connector.id, db_session)
-            continue
-
-        net_doc_change = 0
-        try:
-            if task == InputType.LOAD_STATE:
-                assert isinstance(runnable_connector, LoadConnector)
-                doc_batch_generator = runnable_connector.load_from_state()
-
-            elif task == InputType.POLL:
-                assert isinstance(runnable_connector, PollConnector)
-                if attempt.connector_id is None or attempt.credential_id is None:
-                    raise ValueError(
-                        f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
-                        f"can't fetch time range."
-                    )
-                last_run_time = get_last_successful_attempt_time(
-                    attempt.connector_id, attempt.credential_id, db_session
-                )
-                last_run_time_str = datetime.fromtimestamp(
-                    last_run_time, tz=timezone.utc
-                ).strftime("%Y-%m-%d %H:%M:%S")
-                logger.info(
-                    f"Polling for updates between {last_run_time_str} and {run_time_str}"
-                )
-                doc_batch_generator = runnable_connector.poll_source(
-                    start=last_run_time, end=run_time
+            disable_connector(attempt.connector.id, db_session)
+            raise e
+
+        if task == InputType.LOAD_STATE:
+            assert isinstance(runnable_connector, LoadConnector)
+            doc_batch_generator = runnable_connector.load_from_state()
+
+        elif task == InputType.POLL:
+            assert isinstance(runnable_connector, PollConnector)
+            if attempt.connector_id is None or attempt.credential_id is None:
+                raise ValueError(
+                    f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
+                    f"can't fetch time range."
                 )
+            last_run_time = get_last_successful_attempt_time(
+                attempt.connector_id, attempt.credential_id, db_session
+            )
+            last_run_time_str = datetime.fromtimestamp(
+                last_run_time, tz=timezone.utc
+            ).strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(
+                f"Polling for updates between {last_run_time_str} and {run_time_str}"
+            )
+            doc_batch_generator = runnable_connector.poll_source(
+                start=last_run_time, end=run_time
+            )
 
-            else:
-                # Event types cannot be handled by a background type, leave these untouched
-                continue
+        else:
+            # Event types cannot be handled by a background type
+            raise RuntimeError(f"Invalid task type: {task}")
+
+        return doc_batch_generator, run_time
+
+    doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
+    def _index(
+        db_session: Session,
+        attempt: IndexAttempt,
+        doc_batch_generator: GenerateDocumentsOutput,
+        run_time: float,
+    ) -> None:
+        indexing_pipeline = build_indexing_pipeline()
+
+        run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
+        db_connector = attempt.connector
+        db_credential = attempt.credential
+
+        try:
+            net_doc_change = 0
             document_count = 0
             chunk_count = 0
             for doc_batch in doc_batch_generator:
@@ -229,12 +266,12 @@ def run_indexing_jobs(db_session: Session) -> None:
             mark_attempt_succeeded(attempt, db_session)
 
             update_connector_credential_pair(
+                db_session=db_session,
                 connector_id=db_connector.id,
                 credential_id=db_credential.id,
                 attempt_status=IndexingStatus.SUCCESS,
                 net_docs=net_doc_change,
                 run_dt=run_dt,
-                db_session=db_session,
             )
 
             logger.info(
@@ -243,24 +280,121 @@ def run_indexing_jobs(db_session: Session) -> None:
             logger.info(
                 f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
             )
-        except Exception as e:
-            logger.exception(f"Indexing job with id {attempt.id} failed due to {e}")
             logger.info(
                 f"Failed connector elapsed time: {time.time() - run_time} seconds"
             )
             mark_attempt_failed(attempt, db_session, failure_reason=str(e))
             update_connector_credential_pair(
-                connector_id=db_connector.id,
-                credential_id=db_credential.id,
+                db_session=db_session,
+                connector_id=attempt.connector.id,
+                credential_id=attempt.credential.id,
                 attempt_status=IndexingStatus.FAILED,
                 net_docs=net_doc_change,
                 run_dt=run_dt,
+            )
+            raise e
+
+    _index(db_session, index_attempt, doc_batch_generator, run_time)
+
+
+def _run_indexing_entrypoint(index_attempt_id: int) -> None:
+    """Entrypoint for indexing run when using dask distributed.
+    Wraps the actual logic in a `try` block so that we can catch any exceptions
+    and mark the attempt as failed."""
+    try:
+        # set the indexing attempt ID so that all log messages from this process
+        # will have it added as a prefix
+        IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
+
+        with Session(get_sqlalchemy_engine()) as db_session:
+            attempt = get_index_attempt(
+                db_session=db_session, index_attempt_id=index_attempt_id
+            )
+            if attempt is None:
+                raise RuntimeError(
+                    f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
+                )
+
+            logger.info(
+                f"Running indexing attempt for connector: '{attempt.connector.name}', "
+                f"with config: '{attempt.connector.connector_specific_config}', and "
+                f"with credentials: '{attempt.credential_id}'"
+            )
+            update_connector_credential_pair(
                 db_session=db_session,
+                connector_id=attempt.connector.id,
+                credential_id=attempt.credential.id,
+                attempt_status=IndexingStatus.IN_PROGRESS,
             )
+            _run_indexing(
+                db_session=db_session,
+                index_attempt=attempt,
+            )
+
+            logger.info(
+                f"Completed indexing attempt for connector: '{attempt.connector.name}', "
+                f"with config: '{attempt.connector.connector_specific_config}', and "
+                f"with credentials: '{attempt.credential_id}'"
+            )
+    except Exception as e:
+        logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
+
+
+def kickoff_indexing_jobs(
+    db_session: Session,
+    existing_jobs: dict[int, Future],
+    client: Client,
+) -> dict[int, Future]:
+    existing_jobs_copy = existing_jobs.copy()
 
-def update_loop(delay: int = 10) -> None:
+    new_indexing_attempts = get_not_started_index_attempts(db_session)
+    logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
+
+    if not new_indexing_attempts:
+        return existing_jobs
+
+    for attempt in new_indexing_attempts:
+        if attempt.connector is None:
+            logger.warning(
+                f"Skipping index attempt as Connector has been deleted: {attempt}"
+            )
+            mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
+            continue
+        if attempt.credential is None:
+            logger.warning(
+                f"Skipping index attempt as Credential has been deleted: {attempt}"
+            )
+            mark_attempt_failed(
+                attempt, db_session, failure_reason="Credential is null"
+            )
+            continue
+
+        logger.info(
+            f"Kicking off indexing attempt for connector: '{attempt.connector.name}', "
+            f"with config: '{attempt.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.credential_id}'"
+        )
+        mark_attempt_in_progress(attempt, db_session)
+        run = client.submit(_run_indexing_entrypoint, attempt.id, pure=False)
+        existing_jobs_copy[attempt.id] = run
+
+    return existing_jobs_copy
+
+
+def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+    cluster = LocalCluster(
+        n_workers=num_workers,
+        threads_per_worker=1,
+        # there are warning about high memory usage + "Event loop unresponsive"
+        # which are not relevant to us since our workers are expected to use a
+        # lot of memory + involve CPU intensive tasks that will not relinquish
+        # the event loop
+        silence_logs=logging.ERROR,
+    )
+    client = Client(cluster)
+    existing_jobs: dict[int, Future] = {}
     engine = get_sqlalchemy_engine()
     while True:
         start = time.time()
@@ -268,8 +402,13 @@ def update_loop(delay: int = 10) -> None:
         logger.info(f"Running update, current UTC time: {start_time_utc}")
         try:
             with Session(engine, expire_on_commit=False) as db_session:
-                create_indexing_jobs(db_session)
-                run_indexing_jobs(db_session)
+                existing_jobs = cleanup_indexing_jobs(
+                    db_session=db_session, existing_jobs=existing_jobs
+                )
+                create_indexing_jobs(db_session=db_session, existing_jobs=existing_jobs)
+                existing_jobs = kickoff_indexing_jobs(
+                    db_session=db_session, existing_jobs=existing_jobs, client=client
+                )
         except Exception as e:
             logger.exception(f"Failed to run update due to {e}")
         sleep_time = delay - (time.time() - start)
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index 9b6f455a84d..4aa0a95196a 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -163,6 +163,11 @@
 CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
     "CONTINUE_ON_CONNECTOR_FAILURE", ""
 ).lower() not in ["false", ""]
 
+# Controls how many worker processes we spin up to index documents in the
+# background. This is useful for speeding up indexing, but does require a
+# fairly large amount of memory in order to increase substantially, since
+# each worker loads the embedding models into memory.
+NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
 
 #####
diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py
index 80c4884f79f..bc7a4cc7f54 100644
--- a/backend/danswer/db/connector_credential_pair.py
+++ b/backend/danswer/db/connector_credential_pair.py
@@ -57,12 +57,12 @@ def get_last_successful_attempt_time(
 
 
 def update_connector_credential_pair(
+    db_session: Session,
     connector_id: int,
     credential_id: int,
     attempt_status: IndexingStatus,
-    net_docs: int | None,
-    run_dt: datetime | None,
-    db_session: Session,
+    net_docs: int | None = None,
+    run_dt: datetime | None = None,
 ) -> None:
     cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
     if not cc_pair:
diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py
index 6dd8653d517..c4697f5d3a0 100644
--- a/backend/danswer/db/index_attempt.py
+++ b/backend/danswer/db/index_attempt.py
@@ -18,6 +18,13 @@
 logger = setup_logger()
 
 
+def get_index_attempt(
+    db_session: Session, index_attempt_id: int
+) -> IndexAttempt | None:
+    stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
+    return db_session.scalars(stmt).first()
+
+
 def create_index_attempt(
     connector_id: int,
     credential_id: int,
diff --git a/backend/danswer/listeners/slack_listener.py b/backend/danswer/listeners/slack_listener.py
index e96a6457561..a0fbffc7c47 100644
--- a/backend/danswer/listeners/slack_listener.py
+++ b/backend/danswer/listeners/slack_listener.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from collections.abc import Callable
 from functools import wraps
@@ -183,7 +184,12 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
 
     # TODO: message should be enqueued and processed elsewhere,
     # but doing it here for now for simplicity
-    @retry(tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, backoff=2, logger=logger)
+    @retry(
+        tries=DANSWER_BOT_NUM_RETRIES,
+        delay=0.25,
+        backoff=2,
+        logger=cast(logging.Logger, logger),
+    )
     def _get_answer(question: QuestionRequest) -> QAResponse:
         answer = answer_question(
             question=question,
@@ -227,7 +233,12 @@ def _get_answer(question: QuestionRequest) -> QAResponse:
     else:
         text = f"{answer.answer}\n\n*Warning*: no sources were quoted for this answer, so it may be unreliable 😔\n\n{top_documents_str_with_header}"
 
-    @retry(tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, backoff=2, logger=logger)
+    @retry(
+        tries=DANSWER_BOT_NUM_RETRIES,
+        delay=0.25,
+        backoff=2,
+        logger=cast(logging.Logger, logger),
+    )
     def _respond_in_thread(
         channel: str,
         text: str,
diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py
index ec061450075..86173971d8e 100644
--- a/backend/danswer/utils/logger.py
+++ b/backend/danswer/utils/logger.py
@@ -1,9 +1,26 @@
 import logging
-from logging import Logger
+from collections.abc import MutableMapping
+from typing import Any
 
 from danswer.configs.app_configs import LOG_LEVEL
 
 
+class IndexAttemptSingleton:
+    """Used to tell if this process is an indexing job, and if so what is the
+    unique identifier for this indexing attempt. For things like the API server,
+    main background job (scheduler), etc. this will not be used."""
+
+    _INDEX_ATTEMPT_ID: None | int = None
+
+    @classmethod
+    def get_index_attempt_id(cls) -> None | int:
+        return cls._INDEX_ATTEMPT_ID
+
+    @classmethod
+    def set_index_attempt_id(cls, index_attempt_id: int) -> None:
+        cls._INDEX_ATTEMPT_ID = index_attempt_id
+
+
 def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
     log_level_dict = {
         "CRITICAL": logging.CRITICAL,
@@ -17,14 +34,31 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
     return log_level_dict.get(log_level_str.upper(), logging.INFO)
 
 
+class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
+    """This is used to globally add the index attempt id to all log messages
+    during indexing by workers. This is done so that the logs can be filtered
+    by index attempt ID to get a better idea of what happened during a specific
+    indexing attempt. If the index attempt ID is not set, then this adapter
+    is a no-op."""
+
+    def process(
+        self, msg: str, kwargs: MutableMapping[str, Any]
+    ) -> tuple[str, MutableMapping[str, Any]]:
+        attempt_id = IndexAttemptSingleton.get_index_attempt_id()
+        if attempt_id is None:
+            return msg, kwargs
+
+        return f"[Attempt ID: {attempt_id}] {msg}", kwargs
+
+
 def setup_logger(
     name: str = __name__, log_level: int = get_log_level_from_str()
-) -> Logger:
+) -> logging.LoggerAdapter:
     logger = logging.getLogger(name)
 
     # If the logger already has handlers, assume it was already configured and return it.
     if logger.handlers:
-        return logger
+        return _IndexAttemptLoggingAdapter(logger)
 
     logger.setLevel(log_level)
 
@@ -39,4 +73,4 @@ def setup_logger(
     logger.addHandler(handler)
 
-    return logger
+    return _IndexAttemptLoggingAdapter(logger)
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 1cbb84d9581..235291b31ba 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -2,6 +2,8 @@ alembic==1.10.4
 asyncpg==0.27.0
 atlassian-python-api==3.37.0
 beautifulsoup4==4.12.0
+dask==2023.8.1
+distributed==2023.8.1
 python-dateutil==2.8.2
 fastapi==0.95.0
 fastapi-users==11.0.0
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 130077ac921..5bf92a2b93e 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -69,6 +69,7 @@ services:
       - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
       - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
       - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
+      - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
       - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-}
       - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
       - LOG_LEVEL=${LOG_LEVEL:-info}
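
For readers unfamiliar with dask.distributed, the job-lifecycle pattern the new update loop relies on — submit each indexing attempt to a process-backed LocalCluster and track the returned Futures until they finish — boils down to the minimal sketch below. It is not part of the patch; it only uses the dask.distributed calls that appear above (LocalCluster, Client.submit, Future.done, Future.release), and the worker function, attempt IDs, and worker count are illustrative placeholders.

import logging
import time

from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster


def do_indexing(attempt_id: int) -> None:
    # Placeholder for the real per-attempt indexing work.
    print(f"indexing attempt {attempt_id}")


if __name__ == "__main__":
    # One single-threaded worker process each, mirroring update_loop above;
    # silence_logs hides the high-memory / "event loop unresponsive" warnings
    # the patch's comment refers to.
    cluster = LocalCluster(n_workers=2, threads_per_worker=1, silence_logs=logging.ERROR)
    client = Client(cluster)

    existing_jobs: dict[int, Future] = {}
    for attempt_id in (1, 2, 3):
        # pure=False makes dask re-run the task even if the arguments repeat.
        existing_jobs[attempt_id] = client.submit(do_indexing, attempt_id, pure=False)

    # Poll for finished futures and release them, mirroring cleanup_indexing_jobs.
    while existing_jobs:
        for attempt_id, job in list(existing_jobs.items()):
            if job.done():
                job.release()  # drop the distributed reference to the finished task
                del existing_jobs[attempt_id]
        time.sleep(1)

    client.close()
    cluster.close()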