diff --git a/backend/alembic.ini b/backend/alembic.ini index 10ae5cfdd27..599c46fadd7 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -1,6 +1,6 @@ # A generic, single database configuration. -[alembic] +[DEFAULT] # path to migration scripts script_location = alembic @@ -47,7 +47,8 @@ prepend_sys_path = . # version_path_separator = : # version_path_separator = ; # version_path_separator = space -version_path_separator = os # Use os.pathsep. Default configuration used for new projects. +version_path_separator = os +# Use os.pathsep. Default configuration used for new projects. # set to 'true' to search source files recursively # in each "version_locations" directory @@ -106,3 +107,12 @@ formatter = generic [formatter_generic] format = %(levelname)-5.5s [%(name)s] %(message)s datefmt = %H:%M:%S + + +[alembic] +script_location = alembic +version_locations = %(script_location)s/versions + +[schema_private] +script_location = alembic_tenants +version_locations = %(script_location)s/versions diff --git a/backend/alembic/env.py b/backend/alembic/env.py index d7ac37af562..afa5a9669c1 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,21 +1,22 @@ +from typing import Any import asyncio from logging.config import fileConfig from alembic import context -from danswer.db.engine import build_connection_string -from danswer.db.models import Base from sqlalchemy import pool from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import create_async_engine -from celery.backends.database.session import ResultModelBase # type: ignore -from sqlalchemy.schema import SchemaItem from sqlalchemy.sql import text +from danswer.configs.app_configs import MULTI_TENANT +from danswer.db.engine import build_connection_string +from danswer.db.models import Base +from celery.backends.database.session import ResultModelBase # type: ignore + # Alembic Config object config = context.config # Interpret the config file for Python logging. -# This line sets up loggers basically. if config.config_file_name is not None and config.attributes.get( "configure_logger", True ): @@ -35,8 +36,7 @@ def get_schema_options() -> tuple[str, bool]: for pair in arg.split(","): if "=" in pair: key, value = pair.split("=", 1) - x_args[key] = value - + x_args[key.strip()] = value.strip() schema_name = x_args.get("schema", "public") create_schema = x_args.get("create_schema", "true").lower() == "true" return schema_name, create_schema @@ -46,11 +46,7 @@ def get_schema_options() -> tuple[str, bool]: def include_object( - object: SchemaItem, - name: str, - type_: str, - reflected: bool, - compare_to: SchemaItem | None, + object: Any, name: str, type_: str, reflected: bool, compare_to: Any ) -> bool: if type_ == "table" and name in EXCLUDE_TABLES: return False @@ -59,7 +55,6 @@ def include_object( def run_migrations_offline() -> None: """Run migrations in 'offline' mode. - This configures the context with just a URL and not an Engine, though an Engine is acceptable here as well. By skipping the Engine creation @@ -67,17 +62,18 @@ def run_migrations_offline() -> None: Calls to context.execute() here emit the given string to the script output. 
""" + schema_name, _ = get_schema_options() url = build_connection_string() - schema, _ = get_schema_options() context.configure( url=url, target_metadata=target_metadata, # type: ignore literal_binds=True, include_object=include_object, - dialect_opts={"paramstyle": "named"}, - version_table_schema=schema, + version_table_schema=schema_name, include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, ) with context.begin_transaction(): @@ -85,20 +81,30 @@ def run_migrations_offline() -> None: def do_run_migrations(connection: Connection) -> None: - schema, create_schema = get_schema_options() + schema_name, create_schema = get_schema_options() + + if MULTI_TENANT and schema_name == "public": + raise ValueError( + "Cannot run default migrations in public schema when multi-tenancy is enabled. " + "Please specify a tenant-specific schema." + ) + if create_schema: - connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"')) + connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')) connection.execute(text("COMMIT")) - connection.execute(text(f'SET search_path TO "{schema}"')) + # Set search_path to the target schema + connection.execute(text(f'SET search_path TO "{schema_name}"')) context.configure( connection=connection, target_metadata=target_metadata, # type: ignore - version_table_schema=schema, + include_object=include_object, + version_table_schema=schema_name, include_schemas=True, compare_type=True, compare_server_default=True, + script_location=config.get_main_option("script_location"), ) with context.begin_transaction(): @@ -106,7 +112,6 @@ def do_run_migrations(connection: Connection) -> None: async def run_async_migrations() -> None: - """Run migrations in 'online' mode.""" connectable = create_async_engine( build_connection_string(), poolclass=pool.NullPool, @@ -119,7 +124,6 @@ async def run_async_migrations() -> None: def run_migrations_online() -> None: - """Run migrations in 'online' mode.""" asyncio.run(run_async_migrations()) diff --git a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py index 95b53cbeb41..8e0a8e6072d 100644 --- a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py +++ b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py @@ -20,7 +20,7 @@ def upgrade() -> None: conn = op.get_bind() existing_ids_and_chosen_assistants = conn.execute( - sa.text("select id, chosen_assistants from public.user") + sa.text('select id, chosen_assistants from "user"') ) op.drop_column( "user", @@ -37,7 +37,7 @@ def upgrade() -> None: for id, chosen_assistants in existing_ids_and_chosen_assistants: conn.execute( sa.text( - "update public.user set chosen_assistants = :chosen_assistants where id = :id" + 'update "user" set chosen_assistants = :chosen_assistants where id = :id' ), {"chosen_assistants": json.dumps(chosen_assistants), "id": id}, ) @@ -46,7 +46,7 @@ def upgrade() -> None: def downgrade() -> None: conn = op.get_bind() existing_ids_and_chosen_assistants = conn.execute( - sa.text("select id, chosen_assistants from public.user") + sa.text('select id, chosen_assistants from "user"') ) op.drop_column( "user", @@ -59,7 +59,7 @@ def downgrade() -> None: for id, chosen_assistants in existing_ids_and_chosen_assistants: conn.execute( sa.text( - "update public.user set chosen_assistants = :chosen_assistants where id = :id" + 'update "user" set 
chosen_assistants = :chosen_assistants where id = :id' ), {"chosen_assistants": chosen_assistants, "id": id}, ) diff --git a/backend/alembic_tenants/README.md b/backend/alembic_tenants/README.md new file mode 100644 index 00000000000..f075b958305 --- /dev/null +++ b/backend/alembic_tenants/README.md @@ -0,0 +1,3 @@ +These files are for public table migrations when operating with multi tenancy. + +If you are not a Danswer developer, you can ignore this directory entirely. \ No newline at end of file diff --git a/backend/alembic_tenants/env.py b/backend/alembic_tenants/env.py new file mode 100644 index 00000000000..f0f1178ce09 --- /dev/null +++ b/backend/alembic_tenants/env.py @@ -0,0 +1,111 @@ +import asyncio +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.schema import SchemaItem + +from alembic import context +from danswer.db.engine import build_connection_string +from danswer.db.models import PublicBase + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None and config.attributes.get( + "configure_logger", True +): + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = [PublicBase.metadata] + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + +EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} + + +def include_object( + object: SchemaItem, + name: str, + type_: str, + reflected: bool, + compare_to: SchemaItem | None, +) -> bool: + if type_ == "table" and name in EXCLUDE_TABLES: + return False + return True + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = build_connection_string() + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, # type: ignore + include_object=include_object, + ) # type: ignore + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + + connectable = create_async_engine( + build_connection_string(), + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/alembic_tenants/script.py.mako b/backend/alembic_tenants/script.py.mako new file mode 100644 index 00000000000..55df2863d20 --- /dev/null +++ b/backend/alembic_tenants/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/backend/alembic_tenants/versions/14a83a331951_create_usertenantmapping_table.py b/backend/alembic_tenants/versions/14a83a331951_create_usertenantmapping_table.py new file mode 100644 index 00000000000..f8f3016bab1 --- /dev/null +++ b/backend/alembic_tenants/versions/14a83a331951_create_usertenantmapping_table.py @@ -0,0 +1,24 @@ +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "14a83a331951" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "user_tenant_mapping", + sa.Column("email", sa.String(), nullable=False), + sa.Column("tenant_id", sa.String(), nullable=False), + sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"), + sa.UniqueConstraint("email", name="uq_email"), + schema="public", + ) + + +def downgrade() -> None: + op.drop_table("user_tenant_mapping", schema="public") diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index db8a97ceb04..9c81899a421 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]): class UserCreate(schemas.BaseUserCreate): role: UserRole = UserRole.BASIC has_web_login: bool | None = True + tenant_id: str | None = None class UserUpdate(schemas.BaseUserUpdate): diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 81607aab884..3fc117b31a0 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -26,11 +26,14 @@ from fastapi_users import UUIDIDMixin from fastapi_users.authentication import AuthenticationBackend from fastapi_users.authentication import CookieTransport +from fastapi_users.authentication import JWTStrategy from fastapi_users.authentication import Strategy from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.openapi import OpenAPIResponseType from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase +from sqlalchemy import select +from sqlalchemy.orm import attributes from sqlalchemy.orm import Session from danswer.auth.invited_users import get_invited_users @@ -42,7 +45,9 @@ from danswer.configs.app_configs import DISABLE_AUTH from 
danswer.configs.app_configs import EMAIL_FROM from danswer.configs.app_configs import EXPECTED_API_KEY +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION +from danswer.configs.app_configs import SECRET_JWT_KEY from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import SMTP_PASS from danswer.configs.app_configs import SMTP_PORT @@ -60,15 +65,21 @@ from danswer.db.auth import get_default_admin_user_emails from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db +from danswer.db.auth import SQLAlchemyUserAdminDB +from danswer.db.engine import get_async_session_with_tenant from danswer.db.engine import get_session +from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import AccessToken +from danswer.db.models import OAuthAccount from danswer.db.models import User +from danswer.db.models import UserTenantMapping from danswer.db.users import get_user_by_email from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation +from shared_configs.configs import current_tenant_id logger = setup_logger() @@ -136,8 +147,8 @@ def verify_email_is_invited(email: str) -> None: raise PermissionError("User not on allowed user whitelist") -def verify_email_in_whitelist(email: str) -> None: - with Session(get_sqlalchemy_engine()) as db_session: +def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None: + with get_session_with_tenant(tenant_id) as db_session: if not get_user_by_email(email, db_session): verify_email_is_invited(email) @@ -157,6 +168,20 @@ def verify_email_domain(email: str) -> None: ) +def get_tenant_id_for_email(email: str) -> str: + if not MULTI_TENANT: + return "public" + # Implement logic to get tenant_id from the mapping table + with Session(get_sqlalchemy_engine()) as db_session: + result = db_session.execute( + select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email) + ) + tenant_id = result.scalar_one_or_none() + if tenant_id is None: + raise exceptions.UserNotExists() + return tenant_id + + def send_user_verification_email( user_email: str, token: str, @@ -221,6 +246,29 @@ async def create( raise exceptions.UserAlreadyExists() return user + async def on_after_login( + self, + user: User, + request: Request | None = None, + response: Response | None = None, + ) -> None: + if response is None or not MULTI_TENANT: + return + + tenant_id = get_tenant_id_for_email(user.email) + + tenant_token = jwt.encode( + {"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256" + ) + + response.set_cookie( + key="tenant_details", + value=tenant_token, + httponly=True, + secure=WEB_DOMAIN.startswith("https"), + samesite="lax", + ) + async def oauth_callback( self: "BaseUserManager[models.UOAP, models.ID]", oauth_name: str, @@ -234,45 +282,111 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - verify_email_in_whitelist(account_email) - verify_email_domain(account_email) - - user = await super().oauth_callback( # type: ignore - oauth_name=oauth_name, - access_token=access_token, - account_id=account_id, - account_email=account_email, - expires_at=expires_at, - refresh_token=refresh_token, - request=request, - 
associate_by_email=associate_by_email, - is_verified_by_default=is_verified_by_default, - ) - - # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to - # re-authenticate that frequently, so by default this is disabled - if expires_at and TRACK_EXTERNAL_IDP_EXPIRY: - oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc) - await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry}) - - # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false` - # otherwise, the oidc expiry will always be old, and the user will never be able to login - if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY: - await self.user_db.update(user, update_dict={"oidc_expiry": None}) - - # Handle case where user has used product outside of web and is now creating an account through web - if not user.has_web_login: - await self.user_db.update( - user, - update_dict={ - "is_verified": is_verified_by_default, - "has_web_login": True, - }, + # Get tenant_id from mapping table + try: + tenant_id = ( + get_tenant_id_for_email(account_email) if MULTI_TENANT else "public" ) - user.is_verified = is_verified_by_default - user.has_web_login = True + except exceptions.UserNotExists: + raise HTTPException(status_code=401, detail="User not found") + + if not tenant_id: + raise HTTPException(status_code=401, detail="User not found") + + token = None + async with get_async_session_with_tenant(tenant_id) as db_session: + token = current_tenant_id.set(tenant_id) + # Print a list of tables in the current database session + verify_email_in_whitelist(account_email, tenant_id) + verify_email_domain(account_email) + if MULTI_TENANT: + tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount) + self.user_db = tenant_user_db + self.database = tenant_user_db + + oauth_account_dict = { + "oauth_name": oauth_name, + "access_token": access_token, + "account_id": account_id, + "account_email": account_email, + "expires_at": expires_at, + "refresh_token": refresh_token, + } + + try: + # Attempt to get user by OAuth account + user = await self.get_by_oauth_account(oauth_name, account_id) + + except exceptions.UserNotExists: + try: + # Attempt to get user by email + user = await self.get_by_email(account_email) + if not associate_by_email: + raise exceptions.UserAlreadyExists() + + user = await self.user_db.add_oauth_account( + user, oauth_account_dict + ) + + # If user not found by OAuth account or email, create a new user + except exceptions.UserNotExists: + password = self.password_helper.generate() + user_dict = { + "email": account_email, + "hashed_password": self.password_helper.hash(password), + "is_verified": is_verified_by_default, + } + + user = await self.user_db.create(user_dict) + user = await self.user_db.add_oauth_account( + user, oauth_account_dict + ) + await self.on_after_register(user, request) - return user + else: + for existing_oauth_account in user.oauth_accounts: + if ( + existing_oauth_account.account_id == account_id + and existing_oauth_account.oauth_name == oauth_name + ): + user = await self.user_db.update_oauth_account( + user, existing_oauth_account, oauth_account_dict + ) + + # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to + # re-authenticate that frequently, so by default this is disabled + + if expires_at and TRACK_EXTERNAL_IDP_EXPIRY: + oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc) + await self.user_db.update( + user, update_dict={"oidc_expiry": 
oidc_expiry} + ) + + # Handle case where user has used product outside of web and is now creating an account through web + if not user.has_web_login: # type: ignore + await self.user_db.update( + user, + { + "is_verified": is_verified_by_default, + "has_web_login": True, + }, + ) + user.is_verified = is_verified_by_default + user.has_web_login = True # type: ignore + + # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false` + # otherwise, the oidc expiry will always be old, and the user will never be able to login + if ( + user.oidc_expiry is not None # type: ignore + and not TRACK_EXTERNAL_IDP_EXPIRY + ): + await self.user_db.update(user, {"oidc_expiry": None}) + user.oidc_expiry = None # type: ignore + + if token: + current_tenant_id.reset(token) + + return user async def on_after_register( self, user: User, request: Optional[Request] = None @@ -303,28 +417,51 @@ async def on_after_request_verify( async def authenticate( self, credentials: OAuth2PasswordRequestForm ) -> Optional[User]: - try: - user = await self.get_by_email(credentials.username) - except exceptions.UserNotExists: + email = credentials.username + + # Get tenant_id from mapping table + + tenant_id = get_tenant_id_for_email(email) + if not tenant_id: + # User not found in mapping self.password_helper.hash(credentials.password) return None - if not user.has_web_login: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD", + # Create a tenant-specific session + async with get_async_session_with_tenant(tenant_id) as tenant_session: + tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase( + tenant_session, User ) + self.user_db = tenant_user_db - verified, updated_password_hash = self.password_helper.verify_and_update( - credentials.password, user.hashed_password - ) - if not verified: - return None + # Proceed with authentication + try: + user = await self.get_by_email(email) - if updated_password_hash is not None: - await self.user_db.update(user, {"hashed_password": updated_password_hash}) + except exceptions.UserNotExists: + self.password_helper.hash(credentials.password) + return None - return user + has_web_login = attributes.get_attribute(user, "has_web_login") + + if not has_web_login: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD", + ) + + verified, updated_password_hash = self.password_helper.verify_and_update( + credentials.password, user.hashed_password + ) + if not verified: + return None + + if updated_password_hash is not None: + await self.user_db.update( + user, {"hashed_password": updated_password_hash} + ) + + return user async def get_user_manager( @@ -339,20 +476,26 @@ async def get_user_manager( ) +def get_jwt_strategy() -> JWTStrategy: + return JWTStrategy( + secret=USER_AUTH_SECRET, + lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS, + ) + + def get_database_strategy( access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db), ) -> DatabaseStrategy: - strategy = DatabaseStrategy( + return DatabaseStrategy( access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore ) - return strategy auth_backend = AuthenticationBackend( - name="database", + name="jwt" if MULTI_TENANT else "database", transport=cookie_transport, - get_strategy=get_database_strategy, -) + get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore +) # type: ignore class 
FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): @@ -366,9 +509,11 @@ def get_logout_router( This way the login router does not need to be included """ router = APIRouter() + get_current_user_token = self.authenticator.current_user_token( active=True, verified=requires_verification ) + logout_responses: OpenAPIResponseType = { **{ status.HTTP_401_UNAUTHORIZED: { @@ -415,8 +560,8 @@ async def optional_user_( async def optional_user( request: Request, - user: User | None = Depends(optional_fastapi_current_user), db_session: Session = Depends(get_session), + user: User | None = Depends(optional_fastapi_current_user), ) -> User | None: versioned_fetch_user = fetch_versioned_implementation( "danswer.auth.users", "optional_user_" diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index 5d5450315b5..0e9fb00b1fd 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -23,6 +23,7 @@ from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import celery_is_worker_primary +from danswer.background.update import get_all_tenant_ids from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerRedisLocks @@ -70,7 +71,6 @@ def celery_task_postrun( return task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") - # logger.debug(f"Result: {retval}") if state not in READY_STATES: return @@ -437,48 +437,58 @@ def stop(self, worker: Any) -> None: ##### # Celery Beat (Periodic Tasks) Settings ##### -celery_app.conf.beat_schedule = { - "check-for-vespa-sync": { + +tenant_ids = get_all_tenant_ids() + +tasks_to_schedule = [ + { + "name": "check-for-vespa-sync", "task": "check_for_vespa_sync_task", "schedule": timedelta(seconds=5), "options": {"priority": DanswerCeleryPriority.HIGH}, }, -} -celery_app.conf.beat_schedule.update( { - "check-for-connector-deletion-task": { - "task": "check_for_connector_deletion_task", - # don't need to check too often, since we kick off a deletion initially - # during the API call that actually marks the CC pair for deletion - "schedule": timedelta(seconds=60), - "options": {"priority": DanswerCeleryPriority.HIGH}, - }, - } -) -celery_app.conf.beat_schedule.update( + "name": "check-for-connector-deletion", + "task": "check_for_connector_deletion_task", + "schedule": timedelta(seconds=60), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, { - "check-for-prune": { - "task": "check_for_prune_task_2", - "schedule": timedelta(seconds=60), - "options": {"priority": DanswerCeleryPriority.HIGH}, - }, - } -) -celery_app.conf.beat_schedule.update( + "name": "check-for-prune", + "task": "check_for_prune_task_2", + "schedule": timedelta(seconds=10), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, { - "kombu-message-cleanup": { - "task": "kombu_message_cleanup_task", - "schedule": timedelta(seconds=3600), - "options": {"priority": DanswerCeleryPriority.LOWEST}, - }, - } -) -celery_app.conf.beat_schedule.update( + "name": "kombu-message-cleanup", + "task": "kombu_message_cleanup_task", + "schedule": timedelta(seconds=3600), + "options": {"priority": DanswerCeleryPriority.LOWEST}, + }, { - "monitor-vespa-sync": { - "task": "monitor_vespa_sync", - "schedule": timedelta(seconds=5), - 
"options": {"priority": DanswerCeleryPriority.HIGH}, - }, - } -) + "name": "monitor-vespa-sync", + "task": "monitor_vespa_sync", + "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, +] + +# Build the celery beat schedule dynamically +beat_schedule = {} + +for tenant_id in tenant_ids: + for task in tasks_to_schedule: + task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task + beat_schedule[task_name] = { + "task": task["task"], + "schedule": task["schedule"], + "options": task["options"], + "args": (tenant_id,), # Must pass tenant_id as an argument + } + +# Include any existing beat schedules +existing_beat_schedule = celery_app.conf.beat_schedule or {} +beat_schedule.update(existing_beat_schedule) + +# Update the Celery app configuration once +celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index f08bfd17e2f..1506a4b9be1 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -107,6 +107,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: pass @@ -122,6 +123,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: last_lock_time = time.monotonic() @@ -146,7 +148,7 @@ def generate_tasks( result = celery_app.send_task( "vespa_metadata_sync_task", - kwargs=dict(document_id=doc.id), + kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=DanswerCeleryPriority.LOW, @@ -168,6 +170,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: last_lock_time = time.monotonic() @@ -204,7 +207,7 @@ def generate_tasks( result = celery_app.send_task( "vespa_metadata_sync_task", - kwargs=dict(document_id=doc.id), + kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=DanswerCeleryPriority.LOW, @@ -244,6 +247,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: last_lock_time = time.monotonic() @@ -278,7 +282,7 @@ def generate_tasks( # Priority on sync's triggered by new indexing should be medium result = celery_app.send_task( "vespa_metadata_sync_task", - kwargs=dict(document_id=doc.id), + kwargs=dict(document_id=doc.id, tenant_id=tenant_id), queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=DanswerCeleryPriority.MEDIUM, @@ -300,6 +304,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: last_lock_time = time.monotonic() @@ -336,6 +341,7 @@ def generate_tasks( document_id=doc.id, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, + tenant_id=tenant_id, ), queue=DanswerCeleryQueues.CONNECTOR_DELETION, task_id=custom_task_id, @@ -409,6 +415,7 @@ def generate_tasks( db_session: Session, redis_client: Redis, lock: redis.lock.Lock | None, + tenant_id: str | None, ) -> int | None: last_lock_time = time.monotonic() @@ -442,6 +449,7 @@ def generate_tasks( document_id=doc_id, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, + tenant_id=tenant_id, ), 
queue=DanswerCeleryQueues.CONNECTOR_DELETION, task_id=custom_task_id, diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index a3223aacc9f..6a4c4da8243 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -23,7 +23,7 @@ soft_time_limit=JOB_TIMEOUT, trail=False, ) -def check_for_connector_deletion_task() -> None: +def check_for_connector_deletion_task(tenant_id: str | None) -> None: r = get_redis_client() lock_beat = r.lock( @@ -40,7 +40,7 @@ def check_for_connector_deletion_task() -> None: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: try_generate_document_cc_pair_cleanup_tasks( - cc_pair, db_session, r, lock_beat + cc_pair, db_session, r, lock_beat, tenant_id ) except SoftTimeLimitExceeded: task_logger.info( @@ -58,6 +58,7 @@ def try_generate_document_cc_pair_cleanup_tasks( db_session: Session, r: Redis, lock_beat: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: """Returns an int if syncing is needed. The int represents the number of sync tasks generated. Note that syncing can still be required even if the number of sync tasks generated is zero. @@ -90,7 +91,9 @@ def try_generate_document_cc_pair_cleanup_tasks( task_logger.info( f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}" ) - tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat) + tasks_generated = rcd.generate_tasks( + celery_app, db_session, r, lock_beat, tenant_id + ) if tasks_generated is None: return None diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index f72229b7d8c..28149bb82a3 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -24,17 +24,21 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.document import get_documents_for_connector_credential_pair -from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import setup_logger + + +logger = setup_logger() @shared_task( name="check_for_prune_task_2", soft_time_limit=JOB_TIMEOUT, ) -def check_for_prune_task_2() -> None: +def check_for_prune_task_2(tenant_id: str | None) -> None: r = get_redis_client() lock_beat = r.lock( @@ -47,11 +51,11 @@ def check_for_prune_task_2() -> None: if not lock_beat.acquire(blocking=False): return - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: tasks_created = ccpair_pruning_generator_task_creation_helper( - cc_pair, db_session, r, lock_beat + cc_pair, db_session, tenant_id, r, lock_beat ) if not tasks_created: continue @@ -71,6 +75,7 @@ def check_for_prune_task_2() -> None: def ccpair_pruning_generator_task_creation_helper( cc_pair: ConnectorCredentialPair, db_session: Session, + tenant_id: str | None, r: Redis, lock_beat: redis.lock.Lock, ) -> int | None: @@ -101,13 +106,14 @@ def 
ccpair_pruning_generator_task_creation_helper( if datetime.now(timezone.utc) < next_prune: return None - return try_creating_prune_generator_task(cc_pair, db_session, r) + return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id) def try_creating_prune_generator_task( cc_pair: ConnectorCredentialPair, db_session: Session, r: Redis, + tenant_id: str | None, ) -> int | None: """Checks for any conditions that should block the pruning generator task from being created, then creates the task. @@ -140,7 +146,9 @@ def try_creating_prune_generator_task( celery_app.send_task( "connector_pruning_generator_task", kwargs=dict( - connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + tenant_id=tenant_id, ), queue=DanswerCeleryQueues.CONNECTOR_PRUNING, task_id=custom_task_id, @@ -153,14 +161,16 @@ def try_creating_prune_generator_task( @shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT) -def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None: +def connector_pruning_generator_task( + connector_id: int, credential_id: int, tenant_id: str | None +) -> None: """connector pruning task. For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list""" r = get_redis_client() - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: try: cc_pair = get_connector_credential_pair( db_session=db_session, @@ -218,7 +228,9 @@ def redis_increment_callback(amount: int) -> None: task_logger.info( f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}" ) - tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None) + tasks_generated = rcp.generate_tasks( + celery_app, db_session, r, None, tenant_id + ) if tasks_generated is None: return None diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py index 0977fb35d29..b065122be84 100644 --- a/backend/danswer/background/celery/tasks/shared/tasks.py +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -1,7 +1,6 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from sqlalchemy.orm import Session from danswer.access.access import get_access_for_document from danswer.background.celery.celery_app import task_logger @@ -11,7 +10,7 @@ from danswer.db.document import get_document_connector_count from danswer.db.document import mark_document_as_synced from danswer.db.document_set import fetch_document_sets_for_document -from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import get_session_with_tenant from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import VespaDocumentFields @@ -26,7 +25,11 @@ max_retries=3, ) def document_by_cc_pair_cleanup_task( - self: Task, document_id: str, connector_id: int, credential_id: int + self: Task, + document_id: str, + connector_id: int, + credential_id: int, + tenant_id: str | None, ) -> bool: """A lightweight subtask used to clean up document to cc pair relationships. 
Created by connection deletion and connector pruning parent tasks.""" @@ -44,7 +47,7 @@ def document_by_cc_pair_cleanup_task( (6) delete all relevant entries from postgres """ try: - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: action = "skip" chunks_affected = 0 diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 39b6f8a91e0..e6a017b7ac7 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -38,6 +38,7 @@ from danswer.db.document_set import fetch_document_sets_for_document from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced +from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import delete_index_attempts from danswer.db.models import DocumentSet @@ -61,7 +62,7 @@ soft_time_limit=JOB_TIMEOUT, trail=False, ) -def check_for_vespa_sync_task() -> None: +def check_for_vespa_sync_task(tenant_id: str | None) -> None: """Runs periodically to check if any document needs syncing. Generates sets of tasks for Celery if syncing is needed.""" @@ -77,8 +78,8 @@ def check_for_vespa_sync_task() -> None: if not lock_beat.acquire(blocking=False): return - with Session(get_sqlalchemy_engine()) as db_session: - try_generate_stale_document_sync_tasks(db_session, r, lock_beat) + with get_session_with_tenant(tenant_id) as db_session: + try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id) # check if any document sets are not synced document_set_info = fetch_document_sets( @@ -86,7 +87,7 @@ def check_for_vespa_sync_task() -> None: ) for document_set, _ in document_set_info: try_generate_document_set_sync_tasks( - document_set, db_session, r, lock_beat + document_set, db_session, r, lock_beat, tenant_id ) # check if any user groups are not synced @@ -101,7 +102,7 @@ def check_for_vespa_sync_task() -> None: ) for usergroup in user_groups: try_generate_user_group_sync_tasks( - usergroup, db_session, r, lock_beat + usergroup, db_session, r, lock_beat, tenant_id ) except ModuleNotFoundError: # Always exceptions on the MIT version, which is expected @@ -120,7 +121,7 @@ def check_for_vespa_sync_task() -> None: def try_generate_stale_document_sync_tasks( - db_session: Session, r: Redis, lock_beat: redis.lock.Lock + db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None ) -> int | None: # the fence is up, do nothing if r.exists(RedisConnectorCredentialPair.get_fence_key()): @@ -145,7 +146,9 @@ def try_generate_stale_document_sync_tasks( cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: rc = RedisConnectorCredentialPair(cc_pair.id) - tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat) + tasks_generated = rc.generate_tasks( + celery_app, db_session, r, lock_beat, tenant_id + ) if tasks_generated is None: continue @@ -169,7 +172,11 @@ def try_generate_stale_document_sync_tasks( def try_generate_document_set_sync_tasks( - document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock + document_set: DocumentSet, + db_session: Session, + r: Redis, + lock_beat: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: lock_beat.reacquire() @@ -193,7 +200,9 @@ def try_generate_document_set_sync_tasks( ) # Add all documents that need to be 
updated into the queue - tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat) + tasks_generated = rds.generate_tasks( + celery_app, db_session, r, lock_beat, tenant_id + ) if tasks_generated is None: return None @@ -214,7 +223,11 @@ def try_generate_document_set_sync_tasks( def try_generate_user_group_sync_tasks( - usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock + usergroup: UserGroup, + db_session: Session, + r: Redis, + lock_beat: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: lock_beat.reacquire() @@ -236,7 +249,9 @@ def try_generate_user_group_sync_tasks( task_logger.info( f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}" ) - tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat) + tasks_generated = rug.generate_tasks( + celery_app, db_session, r, lock_beat, tenant_id + ) if tasks_generated is None: return None @@ -471,7 +486,7 @@ def monitor_ccpair_pruning_taskset( @shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True) -def monitor_vespa_sync(self: Task) -> None: +def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None: """This is a celery beat task that monitors and finalizes metadata sync tasksets. It scans for fence values and then gets the counts of any associated tasksets. If the count is 0, that means all tasks finished and we should clean up. @@ -516,7 +531,7 @@ def monitor_vespa_sync(self: Task) -> None: for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): monitor_connector_deletion_taskset(key_bytes, r) - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: lock_beat.reacquire() for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): monitor_document_set_taskset(key_bytes, r, db_session) @@ -556,11 +571,13 @@ def monitor_vespa_sync(self: Task) -> None: time_limit=60, max_retries=3, ) -def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: +def vespa_metadata_sync_task( + self: Task, document_id: str, tenant_id: str | None +) -> bool: task_logger.info(f"document_id={document_id}") try: - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index b3d011a422b..d5e14675c65 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -4,6 +4,7 @@ from datetime import timedelta from datetime import timezone +from sqlalchemy import text from sqlalchemy.orm import Session from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt @@ -17,7 +18,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import get_last_successful_attempt_time from danswer.db.connector_credential_pair import update_connector_credential_pair -from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed @@ -46,6 +47,7 @@ def _get_connector_runner( 
attempt: IndexAttempt, start_time: datetime, end_time: datetime, + tenant_id: str | None, ) -> ConnectorRunner: """ NOTE: `start_time` and `end_time` are only used for poll connectors @@ -87,8 +89,7 @@ def _get_connector_runner( def _run_indexing( - db_session: Session, - index_attempt: IndexAttempt, + db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -129,6 +130,7 @@ def _run_indexing( or (search_settings.status == IndexModelStatus.FUTURE) ), db_session=db_session, + tenant_id=tenant_id, ) db_cc_pair = index_attempt.connector_credential_pair @@ -185,6 +187,7 @@ def _run_indexing( attempt=index_attempt, start_time=window_start, end_time=window_end, + tenant_id=tenant_id, ) all_connector_doc_ids: set[str] = set() @@ -212,7 +215,9 @@ def _run_indexing( db_session.refresh(index_attempt) if index_attempt.status != IndexingStatus.IN_PROGRESS: # Likely due to user manually disabling it or model swap - raise RuntimeError("Index Attempt was canceled") + raise RuntimeError( + f"Index Attempt was canceled, status is {index_attempt.status}" + ) batch_description = [] for doc in doc_batch: @@ -373,12 +378,21 @@ def _run_indexing( ) -def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt: +def _prepare_index_attempt( + db_session: Session, index_attempt_id: int, tenant_id: str | None +) -> IndexAttempt: # make sure that the index attempt can't change in between checking the # status and marking it as in_progress. This setting will be discarded # after the next commit: # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore + if tenant_id is not None: + # Explicitly set the search path for the given tenant + db_session.execute(text(f'SET search_path TO "{tenant_id}"')) + # Verify the search path was set correctly + result = db_session.execute(text("SHOW search_path")) + current_search_path = result.scalar() + logger.info(f"Current search path set to: {current_search_path}") attempt = get_index_attempt( db_session=db_session, @@ -401,12 +415,11 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA def run_indexing_entrypoint( - index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False + index_attempt_id: int, + tenant_id: str | None, + connector_credential_pair_id: int, + is_ee: bool = False, ) -> None: - """Entrypoint for indexing run when using dask distributed. 
- Wraps the actual logic in a `try` block so that we can catch any exceptions - and mark the attempt as failed.""" - try: if is_ee: global_version.set_ee() @@ -416,26 +429,29 @@ def run_indexing_entrypoint( IndexAttemptSingleton.set_cc_and_index_id( index_attempt_id, connector_credential_pair_id ) - - with Session(get_sqlalchemy_engine()) as db_session: - # make sure that it is valid to run this indexing attempt + mark it - # as in progress - attempt = _prepare_index_attempt(db_session, index_attempt_id) + with get_session_with_tenant(tenant_id) as db_session: + attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id) logger.info( - f"Indexing starting: " - f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing starting for tenant {tenant_id}: " + if tenant_id is not None + else "" + + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " f"credentials='{attempt.connector_credential_pair.connector_id}'" ) - _run_indexing(db_session, attempt) + _run_indexing(db_session, attempt, tenant_id) logger.info( - f"Indexing finished: " - f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing finished for tenant {tenant_id}: " + if tenant_id is not None + else "" + + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " f"credentials='{attempt.connector_credential_pair.connector_id}'" ) except Exception as e: - logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}") + logger.exception( + f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}" + ) diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 773165c5161..f7a00687c43 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -6,6 +6,8 @@ from dask.distributed import Client from dask.distributed import Future from distributed import LocalCluster +from sqlalchemy import text +from sqlalchemy.exc import ProgrammingError from sqlalchemy.orm import Session from danswer.background.indexing.dask_utils import ResourceLogger @@ -15,14 +17,16 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS from danswer.configs.constants import DocumentSource from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME +from danswer.configs.constants import TENANT_ID_PREFIX from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import fetch_connector_credential_pairs from danswer.db.engine import get_db_current_time -from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import get_session_with_tenant from danswer.db.engine import SqlEngine from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt @@ -153,13 +157,15 @@ def _mark_run_failed( """Main funcs""" -def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None: +def create_indexing_jobs( + existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None 
+) -> None: """Creates new indexing jobs for each connector / credential pair which is: 1. Enabled 2. `refresh_frequency` time has passed since the last indexing run for this pair 3. There is not already an ongoing indexing attempt for this pair """ - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: ongoing: set[tuple[int | None, int]] = set() for attempt_id in existing_jobs: attempt = get_index_attempt( @@ -214,11 +220,12 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None: def cleanup_indexing_jobs( existing_jobs: dict[int, Future | SimpleJob], + tenant_id: str | None, timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, ) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() # clean up completed jobs - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: for attempt_id, job in existing_jobs.items(): index_attempt = get_index_attempt( db_session=db_session, index_attempt_id=attempt_id @@ -256,38 +263,41 @@ def cleanup_indexing_jobs( ) # clean up in-progress jobs that were never completed - connectors = fetch_connectors(db_session) - for connector in connectors: - in_progress_indexing_attempts = get_inprogress_index_attempts( - connector.id, db_session - ) - for index_attempt in in_progress_indexing_attempts: - if index_attempt.id in existing_jobs: - # If index attempt is canceled, stop the run - if index_attempt.status == IndexingStatus.FAILED: - existing_jobs[index_attempt.id].cancel() - # check to see if the job has been updated in last `timeout_hours` hours, if not - # assume it to frozen in some bad state and just mark it as failed. Note: this relies - # on the fact that the `time_updated` field is constantly updated every - # batch of documents indexed - current_db_time = get_db_current_time(db_session=db_session) - time_since_update = current_db_time - index_attempt.time_updated - if time_since_update.total_seconds() > 60 * 60 * timeout_hours: - existing_jobs[index_attempt.id].cancel() + try: + connectors = fetch_connectors(db_session) + for connector in connectors: + in_progress_indexing_attempts = get_inprogress_index_attempts( + connector.id, db_session + ) + + for index_attempt in in_progress_indexing_attempts: + if index_attempt.id in existing_jobs: + # If index attempt is canceled, stop the run + if index_attempt.status == IndexingStatus.FAILED: + existing_jobs[index_attempt.id].cancel() + # check to see if the job has been updated in last `timeout_hours` hours, if not + # assume it to frozen in some bad state and just mark it as failed. Note: this relies + # on the fact that the `time_updated` field is constantly updated every + # batch of documents indexed + current_db_time = get_db_current_time(db_session=db_session) + time_since_update = current_db_time - index_attempt.time_updated + if time_since_update.total_seconds() > 60 * 60 * timeout_hours: + existing_jobs[index_attempt.id].cancel() + _mark_run_failed( + db_session=db_session, + index_attempt=index_attempt, + failure_reason="Indexing run frozen - no updates in the last three hours. " + "The run will be re-attempted at next scheduled indexing time.", + ) + else: + # If job isn't known, simply mark it as failed _mark_run_failed( db_session=db_session, index_attempt=index_attempt, - failure_reason="Indexing run frozen - no updates in the last three hours. 
" - "The run will be re-attempted at next scheduled indexing time.", + failure_reason=_UNEXPECTED_STATE_FAILURE_REASON, ) - else: - # If job isn't known, simply mark it as failed - _mark_run_failed( - db_session=db_session, - index_attempt=index_attempt, - failure_reason=_UNEXPECTED_STATE_FAILURE_REASON, - ) - + except ProgrammingError: + logger.debug(f"No Connector Table exists for: {tenant_id}") return existing_jobs_copy @@ -295,13 +305,15 @@ def kickoff_indexing_jobs( existing_jobs: dict[int, Future | SimpleJob], client: Client | SimpleJobClient, secondary_client: Client | SimpleJobClient, + tenant_id: str | None, ) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() - engine = get_sqlalchemy_engine() + + current_session = get_session_with_tenant(tenant_id) # Don't include jobs waiting in the Dask queue that just haven't started running # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet - with Session(engine) as db_session: + with current_session as db_session: # get_not_started_index_attempts orders its returned results from oldest to newest # we must process attempts in a FIFO manner to prevent connector starvation new_indexing_attempts = [ @@ -332,7 +344,7 @@ def kickoff_indexing_jobs( logger.warning( f"Skipping index attempt as Connector has been deleted: {attempt}" ) - with Session(engine) as db_session: + with current_session as db_session: mark_attempt_failed( attempt, db_session, failure_reason="Connector is null" ) @@ -341,7 +353,7 @@ def kickoff_indexing_jobs( logger.warning( f"Skipping index attempt as Credential has been deleted: {attempt}" ) - with Session(engine) as db_session: + with current_session as db_session: mark_attempt_failed( attempt, db_session, failure_reason="Credential is null" ) @@ -352,6 +364,7 @@ def kickoff_indexing_jobs( run = client.submit( run_indexing_entrypoint, attempt.id, + tenant_id, attempt.connector_credential_pair_id, global_version.is_ee_version(), pure=False, @@ -363,6 +376,7 @@ def kickoff_indexing_jobs( run = secondary_client.submit( run_indexing_entrypoint, attempt.id, + tenant_id, attempt.connector_credential_pair_id, global_version.is_ee_version(), pure=False, @@ -398,42 +412,40 @@ def kickoff_indexing_jobs( return existing_jobs_copy +def get_all_tenant_ids() -> list[str] | list[None]: + if not MULTI_TENANT: + return [None] + with get_session_with_tenant(tenant_id="public") as session: + result = session.execute( + text( + """ + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')""" + ) + ) + tenant_ids = [row[0] for row in result] + + valid_tenants = [ + tenant + for tenant in tenant_ids + if tenant is None or tenant.startswith(TENANT_ID_PREFIX) + ] + + return valid_tenants + + def update_loop( delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS, num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS, ) -> None: - engine = get_sqlalchemy_engine() - with Session(engine) as db_session: - check_index_swap(db_session=db_session) - search_settings = get_current_search_settings(db_session) - - # So that the first time users aren't surprised by really slow speed of first - # batch of documents indexed - - if search_settings.provider_type is None: - logger.notice("Running a first inference to warm up embedding model") - embedding_model = EmbeddingModel.from_db_model( - search_settings=search_settings, - server_host=INDEXING_MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ) - - 
warm_up_bi_encoder( - embedding_model=embedding_model, - ) - logger.notice("First inference complete.") - client_primary: Client | SimpleJobClient client_secondary: Client | SimpleJobClient if DASK_JOB_CLIENT_ENABLED: cluster_primary = LocalCluster( n_workers=num_workers, threads_per_worker=1, - # there are warning about high memory usage + "Event loop unresponsive" - # which are not relevant to us since our workers are expected to use a - # lot of memory + involve CPU intensive tasks that will not relinquish - # the event loop silence_logs=logging.ERROR, ) cluster_secondary = LocalCluster( @@ -449,7 +461,7 @@ def update_loop( client_primary = SimpleJobClient(n_workers=num_workers) client_secondary = SimpleJobClient(n_workers=num_secondary_workers) - existing_jobs: dict[int, Future | SimpleJob] = {} + existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {} logger.notice("Startup complete. Waiting for indexing jobs...") while True: @@ -458,24 +470,58 @@ def update_loop( logger.debug(f"Running update, current UTC time: {start_time_utc}") if existing_jobs: - # TODO: make this debug level once the "no jobs are being scheduled" issue is resolved logger.debug( "Found existing indexing jobs: " - f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}" + f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}" ) try: - with Session(get_sqlalchemy_engine()) as db_session: - check_index_swap(db_session) - existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs) - create_indexing_jobs(existing_jobs=existing_jobs) - existing_jobs = kickoff_indexing_jobs( - existing_jobs=existing_jobs, - client=client_primary, - secondary_client=client_secondary, - ) + tenants = get_all_tenant_ids() + + for tenant_id in tenants: + try: + logger.debug( + f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}" + ) + with get_session_with_tenant(tenant_id) as db_session: + check_index_swap(db_session=db_session) + if not MULTI_TENANT: + search_settings = get_current_search_settings(db_session) + if search_settings.provider_type is None: + logger.notice( + "Running a first inference to warm up embedding model" + ) + embedding_model = EmbeddingModel.from_db_model( + search_settings=search_settings, + server_host=INDEXING_MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + warm_up_bi_encoder(embedding_model=embedding_model) + logger.notice("First inference complete.") + + tenant_jobs = existing_jobs.get(tenant_id, {}) + + tenant_jobs = cleanup_indexing_jobs( + existing_jobs=tenant_jobs, tenant_id=tenant_id + ) + create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id) + tenant_jobs = kickoff_indexing_jobs( + existing_jobs=tenant_jobs, + client=client_primary, + secondary_client=client_secondary, + tenant_id=tenant_id, + ) + + existing_jobs[tenant_id] = tenant_jobs + + except Exception as e: + logger.exception( + f"Failed to process tenant {tenant_id or 'default'}: {e}" + ) + except Exception as e: logger.exception(f"Failed to run update due to {e}") + sleep_time = delay - (time.time() - start) if sleep_time > 0: time.sleep(sleep_time) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index eaa231e88b7..04925262196 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -429,3 +429,5 @@ DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "") EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "") + +ENABLE_EMAIL_INVITES = 
os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true" diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 4c43dfcf634..e4aeb88c279 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -31,6 +31,9 @@ "You can still use Danswer as a search engine." ) +# Prefix used for all tenant ids +TENANT_ID_PREFIX = "tenant_" + # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 6d150b106cb..dc3f5a837bd 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -13,7 +13,7 @@ from danswer.auth.schemas import UserRole from danswer.db.engine import get_async_session -from danswer.db.engine import get_sqlalchemy_async_engine +from danswer.db.engine import get_async_session_with_tenant from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User @@ -34,7 +34,7 @@ def get_default_admin_user_emails() -> list[str]: async def get_user_count() -> int: - async with AsyncSession(get_sqlalchemy_async_engine()) as asession: + async with get_async_session_with_tenant() as asession: stmt = select(func.count(User.id)) result = await asession.execute(stmt) user_count = result.scalar() diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index f9d79df96ae..b3e1de7647a 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -390,6 +390,7 @@ def add_credential_to_connector( ) db_session.add(association) db_session.flush() # make sure the association has an id + db_session.refresh(association) if groups and access_type != AccessType.SYNC: _relate_groups_to_cc_pair__no_commit( diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index af7aad23669..a1f2335d348 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -1,16 +1,16 @@ import contextlib -import contextvars import re import threading import time from collections.abc import AsyncGenerator from collections.abc import Generator +from contextlib import asynccontextmanager +from contextlib import contextmanager from datetime import datetime from typing import Any from typing import ContextManager import jwt -from fastapi import Depends from fastapi import HTTPException from fastapi import Request from sqlalchemy import event @@ -39,7 +39,7 @@ from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.utils.logger import setup_logger - +from shared_configs.configs import current_tenant_id logger = setup_logger() @@ -230,18 +230,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: return _ASYNC_ENGINE -# Context variable to store the current tenant ID -# This allows us to maintain tenant-specific context throughout the request lifecycle -# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups -# This context variable works in both synchronous and asynchronous contexts -# In async code, it's automatically carried across coroutines -# In sync code, it's managed per thread -current_tenant_id = contextvars.ContextVar( - "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA -) - - -# Dependency to get the current tenant ID and set the context variable +# Dependency to get the current tenant ID +# If no 
token is present, uses the default schema for this use case def get_current_tenant_id(request: Request) -> str: """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable.""" if not MULTI_TENANT: @@ -251,32 +241,31 @@ def get_current_tenant_id(request: Request) -> str: token = request.cookies.get("tenant_details") if not token: + current_value = current_tenant_id.get() # If no token is present, use the default schema or handle accordingly - tenant_id = POSTGRES_DEFAULT_SCHEMA - current_tenant_id.set(tenant_id) - return tenant_id + return current_value try: payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"]) tenant_id = payload.get("tenant_id") if not tenant_id: - raise HTTPException( - status_code=400, detail="Invalid token: tenant_id missing" - ) + return current_tenant_id.get() if not is_valid_schema_name(tenant_id): - raise ValueError("Invalid tenant ID format") + raise HTTPException(status_code=400, detail="Invalid tenant ID format") current_tenant_id.set(tenant_id) + return tenant_id except jwt.InvalidTokenError: - raise HTTPException(status_code=401, detail="Invalid token format") - except ValueError as e: - # Let the 400 error bubble up - raise HTTPException(status_code=400, detail=str(e)) - except Exception: + return current_tenant_id.get() + except Exception as e: + logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") -def get_session_with_tenant(tenant_id: str | None = None) -> Session: +@asynccontextmanager +async def get_async_session_with_tenant( + tenant_id: str | None = None, +) -> AsyncGenerator[AsyncSession, None]: if tenant_id is None: tenant_id = current_tenant_id.get() @@ -284,20 +273,78 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session: logger.error(f"Invalid tenant ID: {tenant_id}") raise Exception("Invalid tenant ID") - engine = SqlEngine.get_engine() - session = Session(engine, expire_on_commit=False) + engine = get_sqlalchemy_async_engine() + async_session_factory = sessionmaker( + bind=engine, expire_on_commit=False, class_=AsyncSession + ) # type: ignore - @event.listens_for(session, "after_begin") - def set_search_path(session: Session, transaction: Any, connection: Any) -> None: - connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id}) + async with async_session_factory() as session: + try: + # Set the search_path to the tenant's schema + await session.execute(text(f'SET search_path = "{tenant_id}"')) + except Exception as e: + logger.error(f"Error setting search_path: {str(e)}") + # You can choose to re-raise the exception or handle it + # Here, we'll re-raise to prevent proceeding with an incorrect session + raise + else: + yield session + + +@contextmanager +def get_session_with_tenant( + tenant_id: str | None = None, +) -> Generator[Session, None, None]: + """Generate a database session with the appropriate tenant schema set.""" + engine = get_sqlalchemy_engine() + if tenant_id is None: + tenant_id = current_tenant_id.get() - return session + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + + # Establish a raw connection without starting a transaction + with engine.connect() as connection: + # Access the raw DBAPI connection + dbapi_connection = connection.connection + + # Execute SET search_path outside of any transaction + cursor = dbapi_connection.cursor() + try: + cursor.execute(f'SET search_path TO 
"{tenant_id}"') + # Optionally verify the search_path was set correctly + cursor.execute("SHOW search_path") + cursor.fetchone() + finally: + cursor.close() + + # Proceed to create a session using the connection + with Session(bind=connection, expire_on_commit=False) as session: + try: + yield session + finally: + # Reset search_path to default after the session is used + if MULTI_TENANT: + cursor = dbapi_connection.cursor() + try: + cursor.execute('SET search_path TO "$user", public') + finally: + cursor.close() + + +def get_session_generator_with_tenant( + tenant_id: str | None = None, +) -> Generator[Session, None, None]: + with get_session_with_tenant(tenant_id) as session: + yield session -def get_session( - tenant_id: str = Depends(get_current_tenant_id), -) -> Generator[Session, None, None]: +def get_session() -> Generator[Session, None, None]: """Generate a database session with the appropriate tenant schema set.""" + tenant_id = current_tenant_id.get() + if tenant_id == "public" and MULTI_TENANT: + raise HTTPException(status_code=401, detail="User must authenticate") + engine = get_sqlalchemy_engine() with Session(engine, expire_on_commit=False) as session: if MULTI_TENANT: @@ -308,10 +355,9 @@ def get_session( yield session -async def get_async_session( - tenant_id: str = Depends(get_current_tenant_id), -) -> AsyncGenerator[AsyncSession, None]: +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: """Generate an async database session with the appropriate tenant schema set.""" + tenant_id = current_tenant_id.get() engine = get_sqlalchemy_async_engine() async with AsyncSession(engine, expire_on_commit=False) as async_session: if MULTI_TENANT: @@ -324,7 +370,7 @@ async def get_async_session( def get_session_context_manager() -> ContextManager[Session]: """Context manager for database sessions.""" - return contextlib.contextmanager(get_session)() + return contextlib.contextmanager(get_session_generator_with_tenant)() def get_session_factory() -> sessionmaker[Session]: diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 392c7a28b2e..fc7dad7793f 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1763,3 +1763,23 @@ class UsageReport(Base): requestor = relationship("User") file = relationship("PGFileStore") + + +""" +Multi-tenancy related tables +""" + + +class PublicBase(DeclarativeBase): + __abstract__ = True + + +class UserTenantMapping(Base): + __tablename__ = "user_tenant_mapping" + __table_args__ = ( + UniqueConstraint("email", "tenant_id", name="uq_user_tenant"), + {"schema": "public"}, + ) + + email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True) + tenant_id: Mapped[str] = mapped_column(String, nullable=False) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 992bce2dccf..d40bd341fdf 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -137,6 +137,7 @@ def index_doc_batch_with_handler( attempt_id: int | None, db_session: Session, ignore_time_skip: bool = False, + tenant_id: str | None = None, ) -> tuple[int, int]: r = (0, 0) try: @@ -148,6 +149,7 @@ def index_doc_batch_with_handler( index_attempt_metadata=index_attempt_metadata, db_session=db_session, ignore_time_skip=ignore_time_skip, + tenant_id=tenant_id, ) except Exception as e: if INDEXING_EXCEPTION_LIMIT == 0: @@ -261,6 +263,7 @@ def index_doc_batch( index_attempt_metadata: IndexAttemptMetadata, db_session: 
Session, ignore_time_skip: bool = False, + tenant_id: str | None = None, ) -> tuple[int, int]: """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the @@ -324,6 +327,7 @@ def index_doc_batch( if chunk.source_document.id in ctx.id_to_db_doc_map else DEFAULT_BOOST ), + tenant_id=tenant_id, ) for chunk in chunks_with_embeddings ] @@ -373,6 +377,7 @@ def build_indexing_pipeline( chunker: Chunker | None = None, ignore_time_skip: bool = False, attempt_id: int | None = None, + tenant_id: str | None = None, ) -> IndexingPipelineProtocol: """Builds a pipeline which takes in a list (batch) of docs and indexes them.""" search_settings = get_current_search_settings(db_session) @@ -416,4 +421,5 @@ def build_indexing_pipeline( ignore_time_skip=ignore_time_skip, attempt_id=attempt_id, db_session=db_session, + tenant_id=tenant_id, ) diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c789a2b351b..39cfa2cca0c 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -75,6 +75,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): negative -> ranked lower. """ + tenant_id: str | None = None access: "DocumentAccess" document_sets: set[str] boost: int @@ -86,6 +87,7 @@ def from_index_chunk( access: "DocumentAccess", document_sets: set[str], boost: int, + tenant_id: str | None, ) -> "DocMetadataAwareIndexChunk": index_chunk_data = index_chunk.model_dump() return cls( @@ -93,6 +95,7 @@ def from_index_chunk( access=access, document_sets=document_sets, boost=boost, + tenant_id=tenant_id, ) diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py index 4306743f875..240ff355b5b 100644 --- a/backend/danswer/key_value_store/store.py +++ b/backend/danswer/key_value_store/store.py @@ -3,15 +3,21 @@ from contextlib import contextmanager from typing import cast +from fastapi import HTTPException +from sqlalchemy import text from sqlalchemy.orm import Session +from danswer.configs.app_configs import MULTI_TENANT from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import is_valid_schema_name from danswer.db.models import KVStore from danswer.key_value_store.interface import JSON_ro from danswer.key_value_store.interface import KeyValueStore from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger +from shared_configs.configs import current_tenant_id + logger = setup_logger() @@ -28,6 +34,16 @@ def __init__(self) -> None: def get_session(self) -> Iterator[Session]: engine = get_sqlalchemy_engine() with Session(engine, expire_on_commit=False) as session: + if MULTI_TENANT: + tenant_id = current_tenant_id.get() + if tenant_id == "public": + raise HTTPException( + status_code=401, detail="User must authenticate" + ) + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + # Set the search_path to the tenant's schema + session.execute(text(f'SET search_path = "{tenant_id}"')) yield session def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index d7ac6b3c3ed..cd0c5c195a6 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -29,6 +29,7 @@ from danswer.configs.app_configs import AUTH_TYPE from 
danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW @@ -157,6 +158,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: verify_auth = fetch_versioned_implementation( "danswer.auth.users", "verify_auth_setting" ) + # Will throw exception if an issue is found verify_auth() @@ -169,11 +171,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # fill up Postgres connection pools await warm_up_connections() - # We cache this at the beginning so there is no delay in the first telemetry - get_or_generate_uuid() + if not MULTI_TENANT: + # We cache this at the beginning so there is no delay in the first telemetry + get_or_generate_uuid() - with Session(engine) as db_session: - setup_danswer(db_session) + # If we are multi-tenant, we need to only set up initial public tables + with Session(engine) as db_session: + setup_danswer(db_session) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index ce3131d050f..ea513b5c21d 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -22,6 +22,7 @@ update_connector_credential_pair_from_id, ) from danswer.db.document import get_document_counts_for_cc_pairs +from danswer.db.engine import current_tenant_id from danswer.db.engine import get_session from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus @@ -257,7 +258,9 @@ def prune_cc_pair( f"credential_id={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." 
) - tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r) + tasks_created = try_creating_prune_generator_task( + cc_pair, db_session, r, current_tenant_id.get() + ) if not tasks_created: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -342,7 +345,7 @@ def sync_cc_pair( logger.info(f"Syncing the {cc_pair.connector.name} connector.") sync_external_doc_permissions_task.apply_async( - kwargs=dict(cc_pair_id=cc_pair_id), + kwargs=dict(cc_pair_id=cc_pair_id, tenant_id=current_tenant_id.get()), ) return StatusResponse( diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 9c87e60b18a..7771c1ed824 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -20,6 +20,7 @@ update_connector_credential_pair_from_id, ) from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed +from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.feedback import fetch_docs_ranked_by_boost @@ -146,6 +147,7 @@ def create_deletion_attempt_for_connector_id( connector_credential_pair_identifier: ConnectorCredentialPairIdentifier, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), + tenant_id: str = Depends(get_current_tenant_id), ) -> None: connector_id = connector_credential_pair_identifier.connector_id credential_id = connector_credential_pair_identifier.credential_id @@ -196,6 +198,7 @@ def create_deletion_attempt_for_connector_id( celery_app.send_task( "check_for_connector_deletion_task", priority=DanswerCeleryPriority.HIGH, + kwargs={"tenant_id": tenant_id}, ) if cc_pair.connector.source == DocumentSource.FILE: diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 2a43542460a..0614a4beb85 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -2,17 +2,21 @@ from datetime import datetime from datetime import timezone +import jwt from email_validator import validate_email from fastapi import APIRouter from fastapi import Body from fastapi import Depends from fastapi import HTTPException +from fastapi import Request from fastapi import status +from psycopg2.errors import UniqueViolation from pydantic import BaseModel from sqlalchemy import Column from sqlalchemy import desc from sqlalchemy import select from sqlalchemy import update +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from danswer.auth.invited_users import get_invited_users @@ -26,9 +30,12 @@ from danswer.auth.users import current_user from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import ENABLE_EMAIL_INVITES +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.constants import AuthType +from danswer.db.engine import current_tenant_id from danswer.db.engine import get_session from danswer.db.models import AccessToken from danswer.db.models import DocumentSet__User @@ -48,10 +55,13 @@ from danswer.server.models import FullUserSnapshot from danswer.server.models import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot +from danswer.server.utils import 
send_user_email_invite from danswer.utils.logger import setup_logger from ee.danswer.db.api_key import is_api_key_email_address from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit from ee.danswer.db.user_group import remove_curator_status__no_commit +from ee.danswer.server.tenants.provisioning import add_users_to_tenant +from ee.danswer.server.tenants.provisioning import remove_users_from_tenant logger = setup_logger() @@ -171,12 +181,33 @@ def bulk_invite_users( raise HTTPException( status_code=400, detail="Auth is disabled, cannot invite users" ) + tenant_id = current_tenant_id.get() normalized_emails = [] for email in emails: email_info = validate_email(email) # can raise EmailNotValidError normalized_emails.append(email_info.normalized) # type: ignore + + if MULTI_TENANT: + try: + add_users_to_tenant(normalized_emails, tenant_id) + except IntegrityError as e: + if isinstance(e.orig, UniqueViolation): + raise HTTPException( + status_code=400, + detail="User has already been invited to a Danswer organization", + ) + raise + all_emails = list(set(normalized_emails) | set(get_invited_users())) + + if MULTI_TENANT and ENABLE_EMAIL_INVITES: + try: + for email in all_emails: + send_user_email_invite(email, current_user) + except Exception as e: + logger.error(f"Error sending email invite to invited users: {e}") + return write_invited_users(all_emails) @@ -187,6 +218,10 @@ def remove_invited_user( ) -> int: user_emails = get_invited_users() remaining_users = [user for user in user_emails if user != user_email.user_email] + + tenant_id = current_tenant_id.get() + remove_users_from_tenant([user_email.user_email], tenant_id) + return write_invited_users(remaining_users) @@ -330,6 +365,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse: return UserRoleResponse(role=user.role) +def get_current_token_expiration_jwt( + user: User | None, request: Request +) -> datetime | None: + if user is None: + return None + + try: + # Get the JWT from the cookie + jwt_token = request.cookies.get("fastapiusersauth") + if not jwt_token: + logger.error("No JWT token found in cookies") + return None + + # Decode the JWT + decoded_token = jwt.decode(jwt_token, options={"verify_signature": False}) + + # Get the 'exp' (expiration) claim from the token + exp = decoded_token.get("exp") + if exp: + return datetime.fromtimestamp(exp) + else: + logger.error("No 'exp' claim found in JWT") + return None + + except Exception as e: + logger.error(f"Error decoding JWT: {e}") + return None + + def get_current_token_creation( user: User | None, db_session: Session ) -> datetime | None: @@ -357,6 +421,7 @@ def get_current_token_creation( @router.get("/me") def verify_user_logged_in( + request: Request, user: User | None = Depends(optional_user), db_session: Session = Depends(get_session), ) -> UserInfo: @@ -380,7 +445,9 @@ def verify_user_logged_in( detail="Access denied. 
User's OIDC token has expired.", ) - token_created_at = get_current_token_creation(user, db_session) + token_created_at = ( + None if MULTI_TENANT else get_current_token_creation(user, db_session) + ) user_info = UserInfo.from_model( user, current_token_created_at=token_created_at, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 91efe6cb874..49603fa3971 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -73,6 +73,7 @@ from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger + logger = setup_logger() router = APIRouter(prefix="/chat") diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index 53ed5b426ba..70404537f70 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -1,7 +1,17 @@ import json +import smtplib from datetime import datetime +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText from typing import Any +from danswer.configs.app_configs import SMTP_PASS +from danswer.configs.app_configs import SMTP_PORT +from danswer.configs.app_configs import SMTP_SERVER +from danswer.configs.app_configs import SMTP_USER +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.db.models import User + class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder that converts datetime objects to ISO format strings.""" @@ -43,3 +53,28 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: masked_creds[key] = mask_string(val) return masked_creds + + +def send_user_email_invite(user_email: str, current_user: User) -> None: + msg = MIMEMultipart() + msg["Subject"] = "Invitation to Join Danswer Workspace" + msg["To"] = user_email + msg["From"] = current_user.email + + email_body = f""" +Hello, + +You have been invited to join a workspace on Danswer. 
+ +To join the workspace, please do so at the following link: +{WEB_DOMAIN}/auth/login + +Best regards, +The Danswer Team""" + + msg.attach(MIMEText(email_body, "plain")) + + with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server: + smtp_server.starttls() + smtp_server.login(SMTP_USER, SMTP_PASS) + smtp_server.send_message(msg) diff --git a/backend/danswer/setup.py b/backend/danswer/setup.py index 2baeda4a811..443ab501d6b 100644 --- a/backend/danswer/setup.py +++ b/backend/danswer/setup.py @@ -4,6 +4,7 @@ from danswer.chat.load_yamls import load_chat_yamls from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.constants import KV_REINDEX_KEY from danswer.configs.constants import KV_SEARCH_SETTINGS from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION @@ -98,7 +99,8 @@ def setup_danswer(db_session: Session) -> None: # Does the user need to trigger a reindexing to bring the document index # into a good state, marked in the kv store - mark_reindex_flag(db_session) + if not MULTI_TENANT: + mark_reindex_flag(db_session) # ensure Vespa is setup correctly logger.notice("Verifying Document Index(s) is/are available.") diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 5dd0f72009f..de57794ee5a 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -1,12 +1,12 @@ from datetime import timedelta -from sqlalchemy.orm import Session - from danswer.background.celery.celery_app import celery_app from danswer.background.task_utils import build_celery_task_wrapper +from danswer.background.update import get_all_tenant_ids from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.app_configs import MULTI_TENANT from danswer.db.chat import delete_chat_sessions_older_than -from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.engine import get_session_with_tenant from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version @@ -32,6 +32,7 @@ run_external_group_permission_sync, ) from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report +from shared_configs.configs import current_tenant_id logger = setup_logger() @@ -41,22 +42,26 @@ @build_celery_task_wrapper(name_sync_external_doc_permissions_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_external_doc_permissions_task(cc_pair_id: int) -> None: - with Session(get_sqlalchemy_engine()) as db_session: +def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None: + with get_session_with_tenant(tenant_id) as db_session: run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) @build_celery_task_wrapper(name_sync_external_group_permissions_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_external_group_permissions_task(cc_pair_id: int) -> None: - with Session(get_sqlalchemy_engine()) as db_session: +def sync_external_group_permissions_task( + cc_pair_id: int, tenant_id: str | None +) -> None: + with get_session_with_tenant(tenant_id) as db_session: run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def perform_ttl_management_task(retention_limit_days: int) -> None: - 
with Session(get_sqlalchemy_engine()) as db_session: +def perform_ttl_management_task( + retention_limit_days: int, tenant_id: str | None +) -> None: + with get_session_with_tenant(tenant_id) as db_session: delete_chat_sessions_older_than(retention_limit_days, db_session) @@ -67,16 +72,16 @@ def perform_ttl_management_task(retention_limit_days: int) -> None: name="check_sync_external_doc_permissions_task", soft_time_limit=JOB_TIMEOUT, ) -def check_sync_external_doc_permissions_task() -> None: +def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None: """Runs periodically to sync external permissions""" - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) for cc_pair in cc_pairs: if should_perform_external_doc_permissions_check( cc_pair=cc_pair, db_session=db_session ): sync_external_doc_permissions_task.apply_async( - kwargs=dict(cc_pair_id=cc_pair.id), + kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id), ) @@ -84,16 +89,16 @@ def check_sync_external_doc_permissions_task() -> None: name="check_sync_external_group_permissions_task", soft_time_limit=JOB_TIMEOUT, ) -def check_sync_external_group_permissions_task() -> None: +def check_sync_external_group_permissions_task(tenant_id: str | None) -> None: """Runs periodically to sync external group permissions""" - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) for cc_pair in cc_pairs: if should_perform_external_group_permissions_check( cc_pair=cc_pair, db_session=db_session ): sync_external_group_permissions_task.apply_async( - kwargs=dict(cc_pair_id=cc_pair.id), + kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id), ) @@ -101,25 +106,33 @@ def check_sync_external_group_permissions_task() -> None: name="check_ttl_management_task", soft_time_limit=JOB_TIMEOUT, ) -def check_ttl_management_task() -> None: +def check_ttl_management_task(tenant_id: str | None) -> None: """Runs periodically to check if any ttl tasks should be run and adds them to the queue""" + token = None + if MULTI_TENANT and tenant_id is not None: + token = current_tenant_id.set(tenant_id) + settings = load_settings() retention_limit_days = settings.maximum_chat_retention_days - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: if should_perform_chat_ttl_check(retention_limit_days, db_session): perform_ttl_management_task.apply_async( - kwargs=dict(retention_limit_days=retention_limit_days), + kwargs=dict( + retention_limit_days=retention_limit_days, tenant_id=tenant_id + ), ) + if token is not None: + current_tenant_id.reset(token) @celery_app.task( name="autogenerate_usage_report_task", soft_time_limit=JOB_TIMEOUT, ) -def autogenerate_usage_report_task() -> None: +def autogenerate_usage_report_task(tenant_id: str | None) -> None: """This generates usage report under the /admin/generate-usage/report endpoint""" - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id) as db_session: create_new_usage_report( db_session=db_session, user_id=None, @@ -130,22 +143,48 @@ def autogenerate_usage_report_task() -> None: ##### # Celery Beat (Periodic Tasks) Settings ##### -celery_app.conf.beat_schedule = { - "sync-external-doc-permissions": { + + +tenant_ids = get_all_tenant_ids() + +tasks_to_schedule = [ + { + "name": 
"sync-external-doc-permissions", "task": "check_sync_external_doc_permissions_task", "schedule": timedelta(seconds=5), # TODO: optimize this }, - "sync-external-group-permissions": { + { + "name": "sync-external-group-permissions", "task": "check_sync_external_group_permissions_task", "schedule": timedelta(seconds=5), # TODO: optimize this }, - "autogenerate_usage_report": { + { + "name": "autogenerate_usage_report", "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, - "check-ttl-management": { + { + "name": "check-ttl-management", "task": "check_ttl_management_task", "schedule": timedelta(hours=1), }, - **(celery_app.conf.beat_schedule or {}), -} +] + +# Build the celery beat schedule dynamically +beat_schedule = {} + +for tenant_id in tenant_ids: + for task in tasks_to_schedule: + task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task + beat_schedule[task_name] = { + "task": task["task"], + "schedule": task["schedule"], + "args": (tenant_id,), # Must pass tenant_id as an argument + } + +# Include any existing beat schedules +existing_beat_schedule = celery_app.conf.beat_schedule or {} +beat_schedule.update(existing_beat_schedule) + +# Update the Celery app configuration +celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py index c494329d366..7a8eee0cd70 100644 --- a/backend/ee/danswer/background/task_name_builders.py +++ b/backend/ee/danswer/background/task_name_builders.py @@ -2,9 +2,13 @@ def name_chat_ttl_task(retention_limit_days: int) -> str: return f"chat_ttl_{retention_limit_days}_days" -def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str: +def name_sync_external_doc_permissions_task( + cc_pair_id: int, tenant_id: str | None = None +) -> str: return f"sync_external_doc_permissions_task__{cc_pair_id}" -def name_sync_external_group_permissions_task(cc_pair_id: int) -> str: +def name_sync_external_group_permissions_task( + cc_pair_id: int, tenant_id: str | None = None +) -> str: return f"sync_external_group_permissions_task__{cc_pair_id}" diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 8422d5494ae..e6483f75ae1 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -4,6 +4,7 @@ from danswer.auth.users import auth_backend from danswer.auth.users import fastapi_users from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import USER_AUTH_SECRET @@ -24,6 +25,7 @@ basic_router as enterprise_settings_router, ) from ee.danswer.server.manage.standard_answer import router as standard_answer_router +from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware from ee.danswer.server.query_and_chat.chat_backend import ( router as chat_router, ) @@ -53,6 +55,9 @@ def get_application() -> FastAPI: application = get_application_base() + if MULTI_TENANT: + add_tenant_id_middleware(application, logger) + if AUTH_TYPE == AuthType.OIDC: include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py new file mode 100644 index 00000000000..f564a4fc683 --- /dev/null +++ 
b/backend/ee/danswer/server/middleware/tenant_tracking.py @@ -0,0 +1,60 @@ +import logging +from collections.abc import Awaitable +from collections.abc import Callable + +import jwt +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Request +from fastapi import Response + +from danswer.configs.app_configs import MULTI_TENANT +from danswer.configs.app_configs import SECRET_JWT_KEY +from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA +from danswer.db.engine import is_valid_schema_name +from shared_configs.configs import current_tenant_id + + +def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None: + @app.middleware("http") + async def set_tenant_id( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + try: + logger.info(f"Request route: {request.url.path}") + + if not MULTI_TENANT: + tenant_id = POSTGRES_DEFAULT_SCHEMA + else: + token = request.cookies.get("tenant_details") + if token: + try: + payload = jwt.decode( + token, SECRET_JWT_KEY, algorithms=["HS256"] + ) + tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + if not is_valid_schema_name(tenant_id): + raise HTTPException( + status_code=400, detail="Invalid tenant ID format" + ) + except jwt.InvalidTokenError: + tenant_id = POSTGRES_DEFAULT_SCHEMA + except Exception as e: + logger.error( + f"Unexpected error in set_tenant_id_middleware: {str(e)}" + ) + raise HTTPException( + status_code=500, detail="Internal server error" + ) + else: + tenant_id = POSTGRES_DEFAULT_SCHEMA + + current_tenant_id.set(tenant_id) + logger.info(f"Middleware set current_tenant_id to: {tenant_id}") + + response = await call_next(request) + return response + + except Exception as e: + logger.error(f"Error in tenant ID middleware: {str(e)}") + raise diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index ec96351856b..b522112ae06 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -8,8 +8,11 @@ from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger from ee.danswer.server.tenants.models import CreateTenantRequest +from ee.danswer.server.tenants.provisioning import add_users_to_tenant from ee.danswer.server.tenants.provisioning import ensure_schema_exists from ee.danswer.server.tenants.provisioning import run_alembic_migrations +from ee.danswer.server.tenants.provisioning import user_owns_a_tenant +from shared_configs.configs import current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") @@ -19,9 +22,15 @@ def create_tenant( create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep) ) -> dict[str, str]: - try: - tenant_id = create_tenant_request.tenant_id + tenant_id = create_tenant_request.tenant_id + email = create_tenant_request.initial_admin_email + token = None + if user_owns_a_tenant(email): + raise HTTPException( + status_code=409, detail="User already belongs to an organization" + ) + try: if not MULTI_TENANT: raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") @@ -31,10 +40,14 @@ def create_tenant( logger.info(f"Schema already exists for tenant {tenant_id}") run_alembic_migrations(tenant_id) + token = current_tenant_id.set(tenant_id) + print("getting session", tenant_id) with get_session_with_tenant(tenant_id) as db_session: setup_danswer(db_session) logger.info(f"Tenant {tenant_id} created successfully") + add_users_to_tenant([email], 
tenant_id) + return { "status": "success", "message": f"Tenant {tenant_id} created successfully", @@ -44,3 +57,6 @@ def create_tenant( raise HTTPException( status_code=500, detail=f"Failed to create tenant: {str(e)}" ) + finally: + if token is not None: + current_tenant_id.reset(token) diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 62436c92e17..77d27e7a551 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -8,7 +8,9 @@ from alembic import command from alembic.config import Config from danswer.db.engine import build_connection_string +from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import UserTenantMapping from danswer.utils.logger import setup_logger logger = setup_logger() @@ -61,3 +63,48 @@ def ensure_schema_exists(tenant_id: str) -> bool: db_session.execute(stmt) return True return False + + +# For now, we're implementing a primitive mapping between users and tenants. +# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). +def user_owns_a_tenant(email: str) -> bool: + with get_session_with_tenant("public") as db_session: + result = ( + db_session.query(UserTenantMapping) + .filter(UserTenantMapping.email == email) + .first() + ) + return result is not None + + +def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant("public") as db_session: + try: + for email in emails: + db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + except Exception as e: + logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}") + db_session.commit() + + +def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant("public") as db_session: + try: + mappings_to_delete = ( + db_session.query(UserTenantMapping) + .filter( + UserTenantMapping.email.in_(emails), + UserTenantMapping.tenant_id == tenant_id, + ) + .all() + ) + + for mapping in mappings_to_delete: + db_session.delete(mapping) + + db_session.commit() + except Exception as e: + logger.exception( + f"Failed to remove users from tenant {tenant_id}: {str(e)}" + ) + db_session.rollback() diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index 96b6b4a0133..70cb2a4a6a8 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -94,6 +94,7 @@ def generate_dummy_chunk( ), document_sets={document_set for document_set in document_set_names}, boost=random.randint(-1, 1), + tenant_id="public", ) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 50233ab6878..898c01509aa 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -1,3 +1,4 @@ +import contextvars import os from typing import List from urllib.parse import urlparse @@ -109,3 +110,5 @@ def validate_cors_origin(origin: str) -> None: else: # If the environment variable is empty, allow all origins CORS_ALLOWED_ORIGIN = ["*"] + +current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public") diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 4d0eff8612d..5298859d13d 100644 --- a/deployment/docker_compose/docker-compose.dev.yml 
+++ b/deployment/docker_compose/docker-compose.dev.yml @@ -29,6 +29,7 @@ services: - SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587' - SMTP_USER=${SMTP_USER:-} - SMTP_PASS=${SMTP_PASS:-} + - ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-} # If enabled, will send users (using SMTP settings) an email to join the workspace - EMAIL_FROM=${EMAIL_FROM:-} - OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-} - OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-} diff --git a/web/src/app/auth/create-account/page.tsx b/web/src/app/auth/create-account/page.tsx new file mode 100644 index 00000000000..b5c340afb6a --- /dev/null +++ b/web/src/app/auth/create-account/page.tsx @@ -0,0 +1,45 @@ +"use client"; + +import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; +import { REGISTRATION_URL } from "@/lib/constants"; +import { Button } from "@tremor/react"; +import Link from "next/link"; +import { FiLogIn } from "react-icons/fi"; + +const Page = () => { + return ( + +
+    <AuthFlowContainer>
+      <div className="flex flex-col w-full max-w-md mx-auto gap-y-4">
+        <h2 className="text-center text-2xl font-bold text-text-900">
+          Account Not Found
+        </h2>
+        <p className="text-center text-text-700">
+          We couldn&apos;t find your account in our records. To access Danswer,
+          you need to either:
+        </p>
+        <ul className="list-disc list-inside text-text-700">
+          <li>create a new organization, or</li>
+          <li>be invited to an existing organization</li>
+        </ul>
+        <div className="flex justify-center">
+          <Link href={REGISTRATION_URL}>
+            <Button size="lg" icon={FiLogIn} color="indigo">
+              Create New Organization
+            </Button>
+          </Link>
+        </div>
+        <div className="text-center text-sm text-text-600">
+          Have an account with a different email?{" "}
+          <Link href="/auth/login" className="text-link font-medium">
+            Sign in
+          </Link>
+        </div>
+      </div>
+    </AuthFlowContainer>
+ ); +}; + +export default Page; diff --git a/web/src/app/auth/error/page.tsx b/web/src/app/auth/error/page.tsx index 4f288cd205f..c75e620068e 100644 --- a/web/src/app/auth/error/page.tsx +++ b/web/src/app/auth/error/page.tsx @@ -1,21 +1,49 @@ "use client"; +import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import { Button } from "@tremor/react"; import Link from "next/link"; import { FiLogIn } from "react-icons/fi"; const Page = () => { return ( -
-    <div className="flex flex-col">
-      <div className="font-bold">
-        Unable to login, please try again and/or contact an administrator.
-      </div>
-      <Link href="/auth/login" className="w-fit">
-        <Button className="mt-4" size="xs" icon={FiLogIn}>
-          Back to login
-        </Button>
-      </Link>
-    </div>
+    <AuthFlowContainer>
+      <div className="flex flex-col w-full max-w-md mx-auto gap-y-4">
+        <h2 className="text-center text-2xl font-bold text-text-900">
+          Authentication Error
+        </h2>
+        <p className="text-center text-text-700">
+          We encountered an issue while attempting to log you in.
+        </p>
+        <div className="rounded-md bg-red-50 p-4">
+          <h3 className="font-semibold text-red-800">Possible Issues:</h3>
+          <ul className="mt-2 space-y-1 text-red-700">
+            <li className="flex items-center">
+              <span className="mr-2">•</span>
+              Incorrect or expired login credentials
+            </li>
+            <li className="flex items-center">
+              <span className="mr-2">•</span>
+              Temporary authentication system disruption
+            </li>
+            <li className="flex items-center">
+              <span className="mr-2">•</span>
+              Account access restrictions or permissions
+            </li>
+          </ul>
+        </div>
+        <Link href="/auth/login" className="w-full">
+          <Button size="lg" icon={FiLogIn} color="indigo" className="w-full">
+            Back to Login
+          </Button>
+        </Link>
+        <p className="text-center text-sm text-text-600">
+          We recommend trying again. If you continue to experience problems,
+          please reach out to your system administrator for assistance.
+        </p>
+      </div>
+    </AuthFlowContainer>
+ ); }; diff --git a/web/src/app/auth/login/LoginText.tsx b/web/src/app/auth/login/LoginText.tsx index a875b407a65..e31aeb81321 100644 --- a/web/src/app/auth/login/LoginText.tsx +++ b/web/src/app/auth/login/LoginText.tsx @@ -6,11 +6,15 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; export const LoginText = () => { const settings = useContext(SettingsContext); - if (!settings) { - throw new Error("SettingsContext is not available"); - } + // if (!settings) { + // throw new Error("SettingsContext is not available"); + // } return ( - <>Log In to {settings?.enterpriseSettings?.application_name || "Danswer"} + <> + Log In to{" "} + {(settings && settings?.enterpriseSettings?.application_name) || + "Danswer"} + ); }; diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index 50f1d42d9b4..9ec047d61e2 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -14,6 +14,7 @@ import Link from "next/link"; import { Logo } from "@/components/Logo"; import { LoginText } from "./LoginText"; import { getSecondsUntilExpiration } from "@/lib/time"; +import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; const Page = async ({ searchParams, @@ -51,7 +52,6 @@ const Page = async ({ if (authTypeMetadata?.requiresVerification && !currentUser.is_verified) { return redirect("/auth/waiting-on-verification"); } - return redirect("/"); } @@ -70,46 +70,44 @@ const Page = async ({ } return ( -
-    <main>
-      <div className="flex flex-col justify-center items-center min-h-screen">
-        <Logo height={64} width={64} className="mx-auto w-fit mb-4" />
-        <div className="flex flex-col w-full justify-center">
-          {authUrl && authTypeMetadata && (
-            <>
-              <SignInButton
-                authorizeUrl={authUrl}
-                authType={authTypeMetadata?.authType}
-              />
-            </>
-          )}
-          {authTypeMetadata?.authType === "basic" && (
-            <Card className="mt-4 w-96">
-              <div className="flex">
-                <Title className="mb-2 mx-auto font-bold">
-                  <LoginText />
-                </Title>
-              </div>
-              <EmailPasswordForm />
-              <div className="flex">
-                <Text className="mt-4 mx-auto">
-                  Don&apos;t have an account?{" "}
-                  <Link href="/auth/signup" className="text-link font-medium">
-                    Create an account
-                  </Link>
-                </Text>
-              </div>
-            </Card>
-          )}
-        </div>
-      </div>
+    <AuthFlowContainer>
+      {authUrl && authTypeMetadata && (
+        <>
+          <SignInButton
+            authorizeUrl={authUrl}
+            authType={authTypeMetadata?.authType}
+          />
+        </>
+      )}
+      {authTypeMetadata?.authType === "basic" && (
+        <Card className="mt-4 w-96">
+          <div className="flex">
+            <Title className="mb-2 mx-auto font-bold">
+              <LoginText />
+            </Title>
+          </div>
+          <EmailPasswordForm />
+          <div className="flex">
+            <Text className="mt-4 mx-auto">
+              Don&apos;t have an account?{" "}
+              <Link href="/auth/signup" className="text-link font-medium">
+                Create an account
+              </Link>
+            </Text>
+          </div>
+        </Card>
+      )}
+    </AuthFlowContainer>
+ ); }; diff --git a/web/src/app/auth/oauth/callback/route.ts b/web/src/app/auth/oauth/callback/route.ts index 0b4157731a1..6e8f290a65f 100644 --- a/web/src/app/auth/oauth/callback/route.ts +++ b/web/src/app/auth/oauth/callback/route.ts @@ -11,6 +11,12 @@ export const GET = async (request: NextRequest) => { const response = await fetch(url.toString()); const setCookieHeader = response.headers.get("set-cookie"); + if (response.status === 401) { + return NextResponse.redirect( + new URL("/auth/create-account", getDomain(request)) + ); + } + if (!setCookieHeader) { return NextResponse.redirect(new URL("/auth/error", getDomain(request))); } diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 9a2631c4350..ec276a09672 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -10,6 +10,7 @@ import { EmailPasswordForm } from "../login/EmailPasswordForm"; import { Card, Title, Text } from "@tremor/react"; import Link from "next/link"; import { Logo } from "@/components/Logo"; +import { CLOUD_ENABLED } from "@/lib/constants"; const Page = async () => { // catch cases where the backend is completely unreachable here @@ -25,6 +26,9 @@ const Page = async () => { } catch (e) { console.log(`Some fetch failed for the login page - ${e}`); } + if (CLOUD_ENABLED) { + return redirect("/auth/login"); + } // simply take the user to the home page if Auth is disabled if (authTypeMetadata?.authType === "disabled") { diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index f64edd17964..f49864aac75 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -19,6 +19,8 @@ import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Logo } from "@/components/Logo"; import { UserProvider } from "@/components/user/UserProvider"; import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; +import { redirect } from "next/navigation"; +import { headers } from "next/headers"; const inter = Inter({ subsets: ["latin"], @@ -56,8 +58,6 @@ export default async function RootLayout({ const combinedSettings = await fetchSettingsSS(); if (!combinedSettings) { - // Just display a simple full page error if fetching fails. - return ( diff --git a/web/src/components/auth/AuthFlowContainer.tsx b/web/src/components/auth/AuthFlowContainer.tsx new file mode 100644 index 00000000000..35fd3d6f3c3 --- /dev/null +++ b/web/src/components/auth/AuthFlowContainer.tsx @@ -0,0 +1,16 @@ +import { Logo } from "../Logo"; + +export default function AuthFlowContainer({ + children, +}: { + children: React.ReactNode; +}) { + return ( +
+    <div className="flex flex-col items-center justify-center min-h-screen px-6">
+      <div className="w-full max-w-md flex flex-col items-center gap-y-4 p-8">
+        <Logo width={70} height={70} />
+        {children}
+      </div>
+    </div>
+ ); +} diff --git a/web/src/components/settings/lib.ts b/web/src/components/settings/lib.ts index f4e14699f1a..1c1ec9249f1 100644 --- a/web/src/components/settings/lib.ts +++ b/web/src/components/settings/lib.ts @@ -40,7 +40,7 @@ export async function fetchSettingsSS(): Promise { let settings: Settings; if (!results[0].ok) { - if (results[0].status === 403) { + if (results[0].status === 403 || results[0].status === 401) { settings = { gpu_enabled: false, chat_page_enabled: true, @@ -62,7 +62,7 @@ export async function fetchSettingsSS(): Promise { let enterpriseSettings: EnterpriseSettings | null = null; if (tasks.length > 1) { if (!results[1].ok) { - if (results[1].status !== 403) { + if (results[1].status !== 403 && results[1].status !== 401) { throw new Error( `fetchEnterpriseSettingsSS failed: status=${results[1].status} body=${await results[1].text()}` ); diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index 974695a8350..4fe1d616dcd 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -55,3 +55,7 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY export const DISABLE_LLM_DOC_RELEVANCE = process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true"; + +export const CLOUD_ENABLED = process.env.NEXT_PUBLIC_CLOUD_ENABLED; +export const REGISTRATION_URL = + process.env.INTERNAL_URL || "http://127.0.0.1:3001";
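
Taken together, the backend changes above hinge on two primitives introduced in this diff: the `current_tenant_id` context variable (shared_configs/configs.py) and `get_session_with_tenant` (danswer/db/engine.py), which opens a session whose Postgres `search_path` points at a single tenant schema. The sketch below is illustrative only and not part of the diff: the schema name "tenant_acme" and the helper function are hypothetical, and it assumes the schema has already been provisioned (e.g. via run_alembic_migrations).

from sqlalchemy import text

from danswer.db.engine import get_session_with_tenant
from shared_configs.configs import current_tenant_id


def count_users(tenant_id: str = "tenant_acme") -> int:  # hypothetical helper
    # Scope any nested code that calls current_tenant_id.get() to this tenant,
    # mirroring the set/reset pattern used by check_ttl_management_task above.
    token = current_tenant_id.set(tenant_id)
    try:
        with get_session_with_tenant(tenant_id) as db_session:
            # search_path already points at the tenant schema, so the
            # unqualified "user" table resolves inside that schema.
            return db_session.execute(
                text('SELECT COUNT(*) FROM "user"')
            ).scalar_one()
    finally:
        current_tenant_id.reset(token)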