Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tenant context #2596

Merged
merged 23 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions backend/alembic.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# A generic, single database configuration.

[alembic]
[DEFAULT]
# path to migration scripts
script_location = alembic

Expand Down Expand Up @@ -47,7 +47,8 @@ prepend_sys_path = .
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
Expand Down Expand Up @@ -106,3 +107,12 @@ formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S


[alembic]
script_location = alembic
version_locations = %(script_location)s/versions

[schema_private]
script_location = alembic_tenants
version_locations = %(script_location)s/versions
48 changes: 26 additions & 22 deletions backend/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from typing import Any
import asyncio
from logging.config import fileConfig

from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text

from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore

# Alembic Config object
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
Expand All @@ -35,8 +36,7 @@ def get_schema_options() -> tuple[str, bool]:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value

x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
Expand All @@ -46,11 +46,7 @@ def get_schema_options() -> tuple[str, bool]:


def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
Expand All @@ -59,54 +55,63 @@ def include_object(

def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    # get_schema_options() returns (schema_name, create_schema); the
    # create_schema flag is discarded here.
    schema_name, _ = get_schema_options()
    url = build_connection_string()

    context.configure(
        url=url,
        target_metadata=target_metadata,  # type: ignore
        literal_binds=True,
        include_object=include_object,
        # Keep the version table in the same schema the migrations target.
        version_table_schema=schema_name,
        include_schemas=True,
        script_location=config.get_main_option("script_location"),
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    """Run migrations over ``connection`` against the configured schema.

    The schema name and the create-schema flag are read via
    get_schema_options(). The schema is optionally created, the
    connection's search_path is pointed at it, and the Alembic context
    is configured with that schema as the version table schema.

    Args:
        connection: Synchronous connection the migrations run on.

    Raises:
        ValueError: If MULTI_TENANT is enabled and the target schema is
            "public".
    """
    schema_name, create_schema = get_schema_options()

    if MULTI_TENANT and schema_name == "public":
        raise ValueError(
            "Cannot run default migrations in public schema when multi-tenancy is enabled. "
            "Please specify a tenant-specific schema."
        )

    # The schema name is interpolated into SQL as a quoted identifier, so
    # any embedded double quote is doubled to keep it inside the identifier.
    quoted_schema = '"' + schema_name.replace('"', '""') + '"'

    if create_schema:
        connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}"))
        connection.execute(text("COMMIT"))

    # Set search_path to the target schema
    connection.execute(text(f"SET search_path TO {quoted_schema}"))

    context.configure(
        connection=connection,
        target_metadata=target_metadata,  # type: ignore
        include_object=include_object,
        version_table_schema=schema_name,
        include_schemas=True,
        compare_type=True,
        compare_server_default=True,
        script_location=config.get_main_option("script_location"),
    )

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
Expand All @@ -119,7 +124,6 @@ async def run_async_migrations() -> None:


def run_migrations_online() -> None:
    """Run migrations in 'online' mode by driving the async runner to completion."""
    migration_coro = run_async_migrations()
    asyncio.run(migration_coro)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
Expand All @@ -37,7 +37,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
Expand All @@ -46,7 +46,7 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
Expand All @@ -59,7 +59,7 @@ def downgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)
3 changes: 3 additions & 0 deletions backend/alembic_tenants/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
These files contain the migrations for tables in the `public` schema when operating with multi-tenancy.

If you are not a Danswer developer, you can ignore this directory entirely.
111 changes: 111 additions & 0 deletions backend/alembic_tenants/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import asyncio
from logging.config import fileConfig

from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem

from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import PublicBase

# This is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# Logging is only configured when a config file is in use and the
# "configure_logger" attribute on the config has not been set to a falsy
# value (it defaults to True when absent).
if config.config_file_name is not None and config.attributes.get(
    "configure_logger", True
):
    fileConfig(config.config_file_name)

# MetaData object(s) handed to context.configure() as target_metadata
# (used for 'autogenerate' support). Only PublicBase's metadata is
# listed in this environment.
target_metadata = [PublicBase.metadata]

# Other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

# Table names that include_object() rejects.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}


def include_object(
    object: SchemaItem,
    name: str,
    type_: str,
    reflected: bool,
    compare_to: SchemaItem | None,
) -> bool:
    """Decide whether a schema object should be considered by Alembic.

    Returns False only for objects of type "table" whose name is listed
    in EXCLUDE_TABLES; everything else is included. The remaining
    parameters are accepted but not inspected.
    """
    is_excluded_table = type_ == "table" and name in EXCLUDE_TABLES
    return not is_excluded_table


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured from a connection URL only, without
    creating an Engine, so no DBAPI needs to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    offline_options = {
        "url": build_connection_string(),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)  # type: ignore

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on ``connection`` and run the migrations.

    Args:
        connection: Synchronous connection the migrations run on.
    """
    online_options = {
        "connection": connection,
        "target_metadata": target_metadata,
        "include_object": include_object,
    }
    context.configure(**online_options)  # type: ignore

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Create an async Engine and run the migrations over one of its connections.

    The engine is built from build_connection_string() with NullPool, and
    do_run_migrations is executed through the connection's run_sync().
    The engine is disposed whether or not the migrations succeed.
    """
    connectable = create_async_engine(
        build_connection_string(),
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        # Dispose even when a migration raises, so the engine is always
        # released before the error propagates.
        await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode by driving the async runner to completion."""
    migration_coro = run_async_migrations()
    asyncio.run(migration_coro)


# Entry point: Alembic executes this module, and the mode reported by the
# context selects which runner is used.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
24 changes: 24 additions & 0 deletions backend/alembic_tenants/script.py.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade() -> None:
${upgrades if upgrades else "pass"}


def downgrade() -> None:
${downgrades if downgrades else "pass"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "14a83a331951"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``user_tenant_mapping`` table in the ``public`` schema.

    The table has two non-nullable string columns, ``email`` and
    ``tenant_id``, with a unique constraint on the pair and another on
    ``email`` alone.
    """
    columns = [
        sa.Column("email", sa.String(), nullable=False),
        sa.Column("tenant_id", sa.String(), nullable=False),
    ]
    constraints = [
        sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
        sa.UniqueConstraint("email", name="uq_email"),
    ]
    op.create_table(
        "user_tenant_mapping",
        *columns,
        *constraints,
        schema="public",
    )


def downgrade() -> None:
    """Drop the ``user_tenant_mapping`` table from the ``public`` schema."""
    table_name, schema = "user_tenant_mapping", "public"
    op.drop_table(table_name, schema=schema)
1 change: 1 addition & 0 deletions backend/danswer/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
    # Role given to the user being created; defaults to UserRole.BASIC.
    role: UserRole = UserRole.BASIC
    # Optional flag, True by default.
    # NOTE(review): presumably indicates whether the user may log in via the web UI — confirm.
    has_web_login: bool | None = True
    # Optional tenant identifier; None when not supplied.
    # NOTE(review): presumably the tenant the new user belongs to under multi-tenancy — confirm against callers.
    tenant_id: str | None = None


class UserUpdate(schemas.BaseUserUpdate):
Expand Down
Loading
Loading