Skip to content

Commit

Permalink
add tenant expulsion/gating + invite user -> increment billing seat no.
Browse files Browse the repository at this point in the history
  • Loading branch information
pablodanswer committed Oct 10, 2024
1 parent a1289bd commit e0f4e1f
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 124 deletions.
30 changes: 0 additions & 30 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import smtplib
import uuid
from collections.abc import AsyncGenerator
Expand All @@ -10,7 +9,6 @@
from typing import Tuple

import jwt
import requests
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
Expand All @@ -35,7 +33,6 @@
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from requests import HTTPError
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
Expand All @@ -45,7 +42,6 @@
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
Expand Down Expand Up @@ -212,32 +208,6 @@ def send_user_verification_email(
s.send_message(msg)


def register_tenant_users(tenant_id: str, number_of_users: int) -> None:
"""
Send a request to the control service to register the number of users for a tenant.
"""
url = f"{CONTROL_PLANE_API_BASE_URL}/register-tenant-users"
payload = {"tenant_id": tenant_id, "number_of_users": number_of_users}

try:
response = requests.post(url, json=payload)
response.raise_for_status()
except HTTPError as e:
if e.response.status_code == 403:
try:
error_detail = e.response.json().get("detail", str(e))
except json.JSONDecodeError:
error_detail = str(e)
raise Exception(f"{error_detail}")

logger.error(f"Error registering tenant users: {str(e)}")
raise

except Exception as e:
logger.error(f"Unexpected error registering tenant users: {str(e)}")
raise


class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import optional_user
from danswer.auth.users import register_tenant_users
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import MULTI_TENANT
Expand Down Expand Up @@ -64,6 +63,7 @@
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant

Expand Down Expand Up @@ -216,8 +216,8 @@ def bulk_invite_users(
initial_invited_users = get_invited_users()

all_emails = list(set(normalized_emails) | set(initial_invited_users))

number_of_invited_users = write_invited_users(all_emails)

if not MULTI_TENANT:
return number_of_invited_users
try:
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/server/settings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Settings(BaseModel):
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gated: bool | None = None

def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled
Expand Down
79 changes: 31 additions & 48 deletions backend/ee/danswer/server/tenants/api.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,44 @@
from typing import cast

import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import User
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.auth import get_total_users
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CheckoutSessionCreationRequest
from ee.danswer.server.tenants.models import CheckoutSessionCreationResponse
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import current_tenant_id


stripe.api_key = STRIPE_SECRET_KEY

logger = setup_logger()
router = APIRouter(prefix="/tenants")
stripe.api_key = STRIPE_SECRET_KEY


@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
Expand All @@ -49,17 +49,14 @@ def create_tenant(
)

try:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")

run_alembic_migrations(tenant_id)
token = current_tenant_id.set(tenant_id)
print("getting session", tenant_id)
run_alembic_migrations(tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session)

Expand All @@ -79,42 +76,27 @@ def create_tenant(
current_tenant_id.reset(token)


@router.post("/update-subscription-quantity")
async def update_subscription_quantity(
checkout_session_creation_request: CheckoutSessionCreationRequest,
db_session: Session = Depends(get_session),
_: User = Depends(current_admin_user),
) -> CheckoutSessionCreationResponse:
current_seats = get_total_users(db_session)
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
"""
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
token = current_tenant_id.set(current_tenant_id.get())

if current_seats > checkout_session_creation_request.quantity:
raise HTTPException(
status_code=400,
detail="Too many users are active to downgrade to this quantity.",
)
settings = load_settings()
settings.product_gated = product_gating_request.gate_product
store_settings(settings)

try:
tenant_id = current_tenant_id.get()
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))

subscription = stripe.Subscription.retrieve(stripe_subscription_id)
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"quantity": checkout_session_creation_request.quantity,
}
],
metadata={"tenant_id": str(tenant_id)},
)

return CheckoutSessionCreationResponse(id=updated_subscription.id)
except Exception as e:
logger.exception("Failed to create checkout session")
raise HTTPException(status_code=500, detail=str(e))
if token is not None:
current_tenant_id.reset(token)


@router.get("/billing-information", response_model=BillingInformation)
Expand All @@ -140,4 +122,5 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
32 changes: 32 additions & 0 deletions backend/ee/danswer/server/tenants/billing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from typing import cast

import requests
import stripe

from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from shared_configs.configs import current_tenant_id

stripe.api_key = STRIPE_SECRET_KEY

logger = setup_logger()

Expand Down Expand Up @@ -33,3 +41,27 @@ def fetch_billing_information(tenant_id: str) -> dict:
response.raise_for_status()
billing_info = response.json()
return billing_info


def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Send a request to the control service to register the number of users for a tenant.
"""

tenant_id = current_tenant_id.get()
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))

subscription = stripe.Subscription.retrieve(stripe_subscription_id)
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"quantity": number_of_users,
}
],
metadata={"tenant_id": str(tenant_id)},
)
return updated_subscription
5 changes: 5 additions & 0 deletions backend/ee/danswer/server/tenants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ class CreateTenantRequest(BaseModel):
initial_admin_email: str


class ProductGatingRequest(BaseModel):
tenant_id: str
gate_product: bool


class BillingInformation(BaseModel):
seats: int
subscription_status: str
Expand Down
1 change: 1 addition & 0 deletions web/src/app/admin/settings/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export interface Settings {
notifications: Notification[];
needs_reindexing: boolean;
gpu_enabled: boolean;
product_gated: boolean;
}

export interface Notification {
Expand Down
55 changes: 13 additions & 42 deletions web/src/app/ee/admin/cloud-settings/BillingInformationPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ export default function BillingInformationPage() {
return <div>Loading...</div>;
}

const handleUpgrade = async () => {
try {
const stripe = await stripePromise;
if (!stripe) throw new Error("Stripe failed to load");

const response = await updateSubscriptionQuantity(seats);
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail);
}

// Allow time for Stripe webhook processing
setTimeout(refreshBillingInformation, 200);
setPopup({
message: "Subscription updated successfully",
type: "success",
});
} catch (error) {
console.error("Error updating subscription:", error);
setPopup({
message:
error instanceof Error ? error.message : "An unknown error occurred",
type: "error",
});
}
};

const handleManageSubscription = async () => {
try {
const response = await fetchCustomerPortal();
Expand Down Expand Up @@ -199,6 +172,7 @@ export default function BillingInformationPage() {
</p>
</div>
)}

{billingInformation.subscription_status === "trialing" ? (
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md mt-8">
<p className="text-lg font-medium text-gray-700">
Expand All @@ -207,21 +181,18 @@ export default function BillingInformationPage() {
</div>
) : (
<div className="flex items-center space-x-4 mt-8">
<input
type="number"
min="1"
value={seats}
onChange={(e) => setSeats(Number(e.target.value))}
className="border border-gray-300 rounded-md px-4 py-2 w-32 focus:outline-none focus:ring-2 focus:ring-gray-500 bg-white text-gray-800 shadow-sm transition-all duration-300"
placeholder="Seats"
/>

<button
onClick={handleUpgrade}
className="bg-gray-600 text-white px-6 py-2 rounded-md hover:bg-gray-700 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-gray-500 focus:ring-opacity-50 font-medium shadow-md text-lg"
>
Upgrade Seats
</button>
<div className="flex items-center space-x-4">
<p className="text-lg font-medium text-gray-700">
Current Seats:
</p>
<p className="text-xl font-semibold text-gray-900">
{billingInformation.seats}
</p>
</div>
<p className="text-sm text-gray-500">
Seats automatically update based on adding, removing, or inviting
users.
</p>
</div>
)}
</div>
Expand Down
Loading

0 comments on commit e0f4e1f

Please sign in to comment.