diff --git a/backend/account/ReadMe.md b/backend/account/ReadMe.md deleted file mode 100644 index 35695cfb5..000000000 --- a/backend/account/ReadMe.md +++ /dev/null @@ -1,26 +0,0 @@ -# Basic WorkFlow - -`We can Add Workflows Here` - -## Login - -### Step - -1. Login -2. Get Organizations -3. Set Organization -4. Use organizational APIs /unstract// - -## Switch organization - -1. Get Organizations -2. Set Organization -3. Use organizational APIs /unstract// - -## Get current user and Organization data - -- Use Get User Profile and Get Organization Info APIs - -## Signout - -1.signout APi diff --git a/backend/account/__init__.py b/backend/account/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/account/admin.py b/backend/account/admin.py deleted file mode 100644 index e0b96cce8..000000000 --- a/backend/account/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import Organization, User - -admin.site.register([Organization, User]) diff --git a/backend/account/apps.py b/backend/account/apps.py deleted file mode 100644 index 2c684a9eb..000000000 --- a/backend/account/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class AccountConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "account" diff --git a/backend/account/authentication_controller.py b/backend/account/authentication_controller.py deleted file mode 100644 index f8b41e921..000000000 --- a/backend/account/authentication_controller.py +++ /dev/null @@ -1,474 +0,0 @@ -import logging -from typing import Any, Optional, Union - -from account.authentication_helper import AuthenticationHelper -from account.authentication_plugin_registry import AuthenticationPluginRegistry -from account.authentication_service import AuthenticationService -from account.constants import ( - AuthorizationErrorCode, - Common, - Cookie, - ErrorMessage, - OrganizationMemberModel, -) -from account.custom_exceptions import ( - DuplicateData, - Forbidden, - MethodNotImplemented, - UserNotExistError, -) -from account.dto import ( - MemberInvitation, - OrganizationData, - UserInfo, - UserInviteResponse, - UserRoleData, -) -from account.exceptions import OrganizationNotExist -from account.models import Organization, User -from account.organization import OrganizationService -from account.serializer import ( - GetOrganizationsResponseSerializer, - OrganizationSerializer, - SetOrganizationsResponseSerializer, -) -from account.user import UserService -from django.conf import settings -from django.contrib.auth import logout as django_logout -from django.db.utils import IntegrityError -from django.middleware import csrf -from django.shortcuts import redirect -from django_tenants.utils import tenant_context -from rest_framework import status -from rest_framework.request import Request -from rest_framework.response import Response -from tenant_account.models import OrganizationMember as OrganizationMember -from tenant_account.organization_member_service import OrganizationMemberService -from utils.cache_service import CacheService -from utils.local_context import StateStore -from utils.user_session import UserSessionUtils - -Logger = logging.getLogger(__name__) - - -class AuthenticationController: - """Authentication Controller This controller class manages user - authentication processes.""" - - def __init__(self) -> None: - """This method initializes the controller by selecting the appropriate - authentication plugin based on availability.""" - self.authentication_helper = AuthenticationHelper() - if AuthenticationPluginRegistry.is_plugin_available(): - self.auth_service: AuthenticationService = ( - AuthenticationPluginRegistry.get_plugin() - ) - else: - self.auth_service = AuthenticationService() - - def user_login( - self, - request: Request, - ) -> Any: - return self.auth_service.user_login(request) - - def user_signup(self, request: Request) -> Any: - return self.auth_service.user_signup(request) - - def authorization_callback( - self, request: Request, backend: str = settings.DEFAULT_MODEL_BACKEND - ) -> Any: - """Handle authorization callback. - - This function processes the authorization callback from - an external service. - - Args: - request (Request): Request instance - backend (str, optional): backend used to use login. - Defaults: settings.DEFAULT_MODEL_BACKEND. - - Returns: - Any: Redirect response - """ - try: - return self.auth_service.handle_authorization_callback( - request=request, backend=backend - ) - except Exception as ex: - Logger.error(f"Error while handling authorization callback: {ex}") - return redirect(f"{settings.ERROR_URL}") - - def user_organizations(self, request: Request) -> Any: - """List a user's organizations. - - Args: - user (User): User instance - z_code (str): _description_ - - Returns: - list[OrganizationData]: _description_ - """ - - try: - organizations = self.auth_service.user_organizations(request) - except Exception as ex: - # - self.user_logout(request) - - response = Response( - status=status.HTTP_412_PRECONDITION_FAILED, - ) - if hasattr(ex, "code") and ex.code in { - AuthorizationErrorCode.USF, - AuthorizationErrorCode.USR, - AuthorizationErrorCode.INE001, - AuthorizationErrorCode.INE002, - }: # type: ignore - response.data = ({"domain": ex.data.get("domain"), "code": ex.code},) - return response - # Return in case even if missed unknown exception in - # self.auth_service.user_organizations(request) - return response - - user: User = request.user - org_ids = {org.id for org in organizations} - - CacheService.set_user_organizations(user.user_id, list(org_ids)) - - serialized_organizations = GetOrganizationsResponseSerializer( - organizations, many=True - ).data - response = Response( - status=status.HTTP_200_OK, - data={ - "message": "success", - "organizations": serialized_organizations, - }, - ) - if Cookie.CSRFTOKEN not in request.COOKIES: - csrf_token = csrf.get_token(request) - response.set_cookie(Cookie.CSRFTOKEN, csrf_token) - - return response - - def set_user_organization(self, request: Request, organization_id: str) -> Response: - user: User = request.user - new_organization = False - organization_ids = CacheService.get_user_organizations(user.user_id) - if not organization_ids: - z_organizations: list[OrganizationData] = ( - self.auth_service.get_organizations_by_user_id(user.user_id) - ) - organization_ids = {org.id for org in z_organizations} - if organization_id and organization_id in organization_ids: - organization = OrganizationService.get_organization_by_org_id( - organization_id - ) - if not organization: - try: - organization_data: OrganizationData = ( - self.auth_service.get_organization_by_org_id(organization_id) - ) - except ValueError: - raise OrganizationNotExist() - try: - organization = OrganizationService.create_organization( - organization_data.name, - organization_data.display_name, - organization_data.id, - ) - new_organization = True - except IntegrityError: - raise DuplicateData( - f"{ErrorMessage.ORGANIZATION_EXIST}, \ - {ErrorMessage.DUPLICATE_API}" - ) - organization_member = self.create_tenant_user( - organization=organization, user=user - ) - - if new_organization: - try: - self.auth_service.hubspot_signup_api(request=request) - except MethodNotImplemented: - Logger.info("hubspot_signup_api not implemented") - - try: - self.auth_service.frictionless_onboarding( - organization=organization, user=user - ) - except MethodNotImplemented: - Logger.info("frictionless_onboarding not implemented") - - self.authentication_helper.create_initial_platform_key( - user=user, organization=organization - ) - - user_info: Optional[UserInfo] = self.get_user_info(request) - serialized_user_info = SetOrganizationsResponseSerializer(user_info).data - organization_info = OrganizationSerializer(organization).data - response: Response = Response( - status=status.HTTP_200_OK, - data={ - "is_new_org": new_organization, - "user": serialized_user_info, - "organization": organization_info, - f"{Common.LOG_EVENTS_ID}": StateStore.get(Common.LOG_EVENTS_ID), - }, - ) - current_organization_id = UserSessionUtils.get_organization_id(request) - if current_organization_id: - OrganizationMemberService.remove_user_membership_in_organization_cache( - user_id=user.user_id, - organization_id=current_organization_id, - ) - UserSessionUtils.set_organization_id(request, organization_id) - UserSessionUtils.set_organization_member_role(request, organization_member) - OrganizationMemberService.set_user_membership_in_organization_cache( - user_id=user.user_id, organization_id=organization_id - ) - return response - return Response(status=status.HTTP_403_FORBIDDEN) - - def get_user_info(self, request: Request) -> Optional[UserInfo]: - return self.auth_service.get_user_info(request) - - def is_admin_by_role(self, role: str) -> bool: - """Check the role is act as admin in the context of authentication - plugin. - - Args: - role (str): role - - Returns: - bool: _description_ - """ - return self.auth_service.is_admin_by_role(role=role) - - def get_organization_info(self, org_id: str) -> Optional[Organization]: - organization = OrganizationService.get_organization_by_org_id(org_id=org_id) - return organization - - def make_organization_and_add_member( - self, - user_id: str, - user_name: str, - organization_name: Optional[str] = None, - display_name: Optional[str] = None, - ) -> Optional[OrganizationData]: - return self.auth_service.make_organization_and_add_member( - user_id, user_name, organization_name, display_name - ) - - def make_user_organization_name(self) -> str: - return self.auth_service.make_user_organization_name() - - def make_user_organization_display_name(self, user_name: str) -> str: - return self.auth_service.make_user_organization_display_name(user_name) - - def user_logout(self, request: Request) -> Response: - response = self.auth_service.user_logout(request=request) - organization_id = UserSessionUtils.get_organization_id(request) - user_id = UserSessionUtils.get_user_id(request) - if organization_id: - OrganizationMemberService.remove_user_membership_in_organization_cache( - user_id=user_id, organization_id=organization_id - ) - django_logout(request) - return response - - def get_organization_members_by_org_id( - self, organization_id: Optional[str] = None - ) -> list[OrganizationMember]: - members: list[OrganizationMember] = OrganizationMember.objects.all() - return members - - def get_organization_members_by_user(self, user: User) -> OrganizationMember: - member: OrganizationMember = OrganizationMember.objects.filter( - user=user - ).first() - return member - - def get_user_roles(self) -> list[UserRoleData]: - return self.auth_service.get_roles() - - def get_user_invitations(self, organization_id: str) -> list[MemberInvitation]: - return self.auth_service.get_invitations(organization_id=organization_id) - - def delete_user_invitation(self, organization_id: str, invitation_id: str) -> bool: - return self.auth_service.delete_invitation( - organization_id=organization_id, invitation_id=invitation_id - ) - - def reset_user_password(self, user: User) -> Response: - return self.auth_service.reset_user_password(user) - - def invite_user( - self, - admin: User, - org_id: str, - user_list: list[dict[str, Union[str, None]]], - ) -> list[UserInviteResponse]: - """Invites users to join an organization. - - Args: - admin (User): Admin user initiating the invitation. - org_id (str): ID of the organization to which users are invited. - user_list (list[dict[str, Union[str, None]]]): - List of user details for invitation. - Returns: - list[UserInviteResponse]: List of responses for each - user invitation. - """ - admin_user = OrganizationMember.objects.get(user=admin.id) - if not self.auth_service.is_organization_admin(admin_user): - raise Forbidden() - response = [] - for user_item in user_list: - email = user_item.get("email") - role = user_item.get("role") - if email: - user = OrganizationMemberService.get_user_by_email(email=email) - user_response = {} - user_response["email"] = email - status = False - message = "User is already part of current organization" - # Check if user is already part of current organization - if not user: - status = self.auth_service.invite_user( - admin_user, org_id, email, role=role - ) - message = "User invitation successful." - - response.append( - UserInviteResponse( - email=email, - status="success" if status else "failed", - message=message, - ) - ) - return response - - def remove_users_from_organization( - self, admin: User, organization_id: str, user_emails: list[str] - ) -> bool: - admin_user = OrganizationMember.objects.get(user=admin.id) - user_ids = OrganizationMember.objects.filter( - user__email__in=user_emails - ).values_list(OrganizationMemberModel.USER_ID, OrganizationMemberModel.ID) - user_ids_list: list[str] = [] - pk_list: list[str] = [] - for user in user_ids: - user_ids_list.append(user[0]) - pk_list.append(user[1]) - if len(user_ids_list) > 0: - is_removed = self.auth_service.remove_users_from_organization( - admin=admin_user, - organization_id=organization_id, - user_ids=user_ids_list, - ) - else: - is_removed = False - if is_removed: - AuthenticationHelper.remove_users_from_organization_by_pks(pk_list) - for user_id in user_ids_list: - OrganizationMemberService.remove_user_membership_in_organization_cache( - user_id, organization_id - ) - - return is_removed - - def add_user_role( - self, admin: User, org_id: str, email: str, role: str - ) -> Optional[str]: - admin_user = OrganizationMember.objects.get(user=admin.id) - user = OrganizationMemberService.get_user_by_email(email=email) - if user: - current_roles = self.auth_service.add_organization_user_role( - admin_user, org_id, user.user.user_id, [role] - ) - if current_roles: - self.save_orgnanization_user_role( - user_id=user.user.user_id, role=current_roles[0] - ) - return current_roles[0] - else: - return None - - def remove_user_role( - self, admin: User, org_id: str, email: str, role: str - ) -> Optional[str]: - admin_user = OrganizationMember.objects.get(user=admin.id) - organization_member = OrganizationMemberService.get_user_by_email(email=email) - if organization_member: - current_roles = self.auth_service.remove_organization_user_role( - admin_user, org_id, organization_member.user.user_id, [role] - ) - if current_roles: - self.save_orgnanization_user_role( - user_id=organization_member.user.user_id, - role=current_roles[0], - ) - return current_roles[0] - else: - return None - - def save_orgnanization_user_role(self, user_id: str, role: str) -> None: - organization_user = OrganizationMemberService.get_user_by_user_id( - user_id=user_id - ) - if organization_user: - # consider single role - organization_user.role = role - organization_user.save() - - def create_tenant_user( - self, organization: Organization, user: User - ) -> OrganizationMember: - with tenant_context(organization): - existing_tenant_user = OrganizationMemberService.get_user_by_id(id=user.id) - if existing_tenant_user: - Logger.info(f"{existing_tenant_user.user.email} Already exist") - return existing_tenant_user - else: - account_user = self.get_or_create_user(user=user) - if account_user: - user_roles = self.auth_service.get_organization_role_of_user( - user_id=account_user.user_id, - organization_id=organization.organization_id, - ) - user_role = user_roles[0] - - tenant_user: OrganizationMember = OrganizationMember( - user=user, - role=user_role, - is_login_onboarding_msg=False, - is_prompt_studio_onboarding_msg=False, - ) - tenant_user.save() - return tenant_user - else: - raise UserNotExistError() - - def get_or_create_user( - self, user: User - ) -> Optional[Union[User, OrganizationMember]]: - user_service = UserService() - if user.id: - account_user: Optional[User] = user_service.get_user_by_id(user.id) - if account_user: - return account_user - elif user.email: - account_user = user_service.get_user_by_email(email=user.email) - if account_user: - return account_user - if user.user_id: - user.save() - return user - elif user.email and user.user_id: - account_user = user_service.create_user( - email=user.email, user_id=user.user_id - ) - return account_user - return None diff --git a/backend/account/authentication_helper.py b/backend/account/authentication_helper.py deleted file mode 100644 index 928624e07..000000000 --- a/backend/account/authentication_helper.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from typing import Any - -from account.dto import MemberData -from account.models import Organization, User -from account.user import UserService -from platform_settings.platform_auth_service import PlatformAuthenticationService -from tenant_account.organization_member_service import OrganizationMemberService - -logger = logging.getLogger(__name__) - - -class AuthenticationHelper: - def __init__(self) -> None: - pass - - def list_of_members_from_user_model( - self, model_data: list[Any] - ) -> list[MemberData]: - members: list[MemberData] = [] - for data in model_data: - user_id = data.user_id - email = data.email - name = data.username - - members.append(MemberData(user_id=user_id, email=email, name=name)) - - return members - - @staticmethod - def get_or_create_user_by_email(user_id: str, email: str) -> User: - """Get or create a user with the given email. - - If a user with the given email already exists, return that user. - Otherwise, create a new user with the given email and return it. - - Parameters: - user_id (str): The ID of the user. - email (str): The email of the user. - - Returns: - User: The user with the given email. - """ - user_service = UserService() - user = user_service.get_user_by_email(email) - if user and not user.user_id: - user = user_service.update_user(user, user_id) - if not user: - user = user_service.create_user(email, user_id) - return user - - def create_initial_platform_key( - self, user: User, organization: Organization - ) -> None: - """Create an initial platform key for the given user and organization. - - This method generates a new platform key with the specified parameters - and saves it to the database. The generated key is set as active and - assigned the name "Key #1". The key is associated with the provided - user and organization. - - Parameters: - user (User): The user for whom the platform key is being created. - organization (Organization): - The organization to which the platform key belongs. - - Raises: - Exception: If an error occurs while generating the platform key. - - Returns: - None - """ - try: - PlatformAuthenticationService.generate_platform_key( - is_active=True, - key_name="Key #1", - user=user, - organization=organization, - ) - except Exception: - logger.error( - "Failed to create default platform key for " - f"organization {organization.organization_id}" - ) - - @staticmethod - def remove_users_from_organization_by_pks( - user_pks: list[str], - ) -> None: - """Remove users from an organization by their primary keys. - - Parameters: - user_pks (list[str]): The primary keys of the users to remove. - """ - # removing user from organization - OrganizationMemberService.remove_users_by_user_pks(user_pks) - # removing user m2m relations , while removing user - for user_pk in user_pks: - User.objects.get(pk=user_pk).shared_exported_tools.clear() - User.objects.get(pk=user_pk).shared_custom_tool.clear() - User.objects.get(pk=user_pk).shared_adapters.clear() - - @staticmethod - def remove_user_from_organization_by_user_id( - user_id: str, organization_id: str - ) -> None: - """Remove users from an organization by their user_id. - - Parameters: - user_id (str): The user_id of the users to remove. - """ - # removing user from organization - OrganizationMemberService.remove_user_by_user_id(user_id) - # removing user m2m relations , while removing user - User.objects.get(user_id=user_id).shared_exported_tools.clear() - User.objects.get(user_id=user_id).shared_custom_tool.clear() - User.objects.get(user_id=user_id).shared_adapters.clear() - # removing user from organization cache - OrganizationMemberService.remove_user_membership_in_organization_cache( - user_id=user_id, organization_id=organization_id - ) diff --git a/backend/account/authentication_plugin_registry.py b/backend/account/authentication_plugin_registry.py deleted file mode 100644 index cd630fdaf..000000000 --- a/backend/account/authentication_plugin_registry.py +++ /dev/null @@ -1,96 +0,0 @@ -import logging -import os -from importlib import import_module -from typing import Any - -from account.constants import PluginConfig -from django.apps import apps - -Logger = logging.getLogger(__name__) - - -def _load_plugins() -> dict[str, dict[str, Any]]: - """Iterating through the Authentication plugins and register their - metadata.""" - auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP) - auth_package_path = auth_app.module.__package__ - auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR) - auth_package_path = f"{auth_package_path}.{PluginConfig.AUTH_PLUGIN_DIR}" - auth_modules = {} - - for item in os.listdir(auth_dir): - # Loads a plugin only if name starts with `auth`. - if not item.startswith(PluginConfig.AUTH_MODULE_PREFIX): - continue - # Loads a plugin if it is in a directory. - if os.path.isdir(os.path.join(auth_dir, item)): - auth_module_name = item - # Loads a plugin if it is a shared library. - # Module name is extracted from shared library name. - # `auth.platform_architecture.so` will be file name and - # `auth` will be the module name. - elif item.endswith(".so"): - auth_module_name = item.split(".")[0] - else: - continue - try: - full_module_path = f"{auth_package_path}.{auth_module_name}" - module = import_module(full_module_path) - metadata = getattr(module, PluginConfig.AUTH_METADATA, {}) - if metadata.get(PluginConfig.METADATA_IS_ACTIVE, False): - auth_modules[auth_module_name] = { - PluginConfig.AUTH_MODULE: module, - PluginConfig.AUTH_METADATA: module.metadata, - } - Logger.info( - "Loaded auth plugin: %s, is_active: %s", - module.metadata["name"], - module.metadata["is_active"], - ) - else: - Logger.warning( - "Metadata is not active for %s authentication module.", - auth_module_name, - ) - except ModuleNotFoundError as exception: - Logger.error( - "Error while importing authentication module : %s", - exception, - ) - - if len(auth_modules) > 1: - raise ValueError( - "Multiple authentication modules found." - "Only one authentication method is allowed." - ) - elif len(auth_modules) == 0: - Logger.warning( - "No authentication modules found." - "Application will start without authentication module" - ) - return auth_modules - - -class AuthenticationPluginRegistry: - auth_modules: dict[str, dict[str, Any]] = _load_plugins() - - @classmethod - def is_plugin_available(cls) -> bool: - """Check if any authentication plugin is available. - - Returns: - bool: True if a plugin is available, False otherwise. - """ - return len(cls.auth_modules) > 0 - - @classmethod - def get_plugin(cls) -> Any: - """Get the selected authentication plugin. - - Returns: - AuthenticationService: Selected authentication plugin instance. - """ - chosen_auth_module = next(iter(cls.auth_modules.values())) - chosen_metadata = chosen_auth_module[PluginConfig.AUTH_METADATA] - service_class_name = chosen_metadata[PluginConfig.METADATA_SERVICE_CLASS] - return service_class_name() diff --git a/backend/account/authentication_service.py b/backend/account/authentication_service.py deleted file mode 100644 index d7a98eea8..000000000 --- a/backend/account/authentication_service.py +++ /dev/null @@ -1,394 +0,0 @@ -import logging -import uuid -from typing import Any, Optional - -from account.authentication_helper import AuthenticationHelper -from account.constants import DefaultOrg, ErrorMessage, UserLoginTemplate -from account.custom_exceptions import Forbidden, MethodNotImplemented -from account.dto import ( - CallbackData, - MemberData, - MemberInvitation, - OrganizationData, - ResetUserPasswordDto, - UserInfo, - UserRoleData, -) -from account.enums import UserRole -from account.models import Organization, User -from account.organization import OrganizationService -from account.serializer import LoginRequestSerializer -from django.conf import settings -from django.contrib.auth import authenticate, login, logout -from django.contrib.auth.hashers import make_password -from django.http import HttpRequest -from django.shortcuts import redirect, render -from rest_framework.request import Request -from rest_framework.response import Response -from tenant_account.models import OrganizationMember as OrganizationMember - -Logger = logging.getLogger(__name__) - - -class AuthenticationService: - def __init__(self) -> None: - self.authentication_helper = AuthenticationHelper() - self.default_organization: Organization = self.user_organization() - - def user_login(self, request: Request) -> Any: - """Authenticate and log in a user. - - Args: - request (Request): The HTTP request object. - - Returns: - Any: The response object. - - Raises: - ValueError: If there is an error in the login credentials. - """ - if request.method == "GET": - return self.render_login_page(request) - try: - validated_data = self.validate_login_credentials(request) - username = validated_data.get("username") - password = validated_data.get("password") - except ValueError as e: - return render( - request, - UserLoginTemplate.TEMPLATE, - {UserLoginTemplate.ERROR_PLACE_HOLDER: str(e)}, - ) - if self.authenticate_and_login(request, username, password): - return redirect(settings.WEB_APP_ORIGIN_URL) - - return self.render_login_page_with_error(request, ErrorMessage.USER_LOGIN_ERROR) - - def is_authenticated(self, request: HttpRequest) -> bool: - """Check if the user is authenticated. - - Args: - request (Request): The HTTP request object. - - Returns: - bool: True if the user is authenticated, False otherwise. - """ - return request.user.is_authenticated - - def authenticate_and_login( - self, request: Request, username: str, password: str - ) -> bool: - """Authenticate and log in a user. - - Args: - request (Request): The HTTP request object. - username (str): The username of the user. - password (str): The password of the user. - - Returns: - bool: True if the user is successfully authenticated and logged in, - False otherwise. - """ - user = authenticate(request, username=username, password=password) - if user: - # To avoid conflicts with django superuser - if user.is_superuser: - return False - login(request, user) - return True - # Attempt to initiate default user and authenticate again - if self.set_default_user(username, password): - user = authenticate(request, username=username, password=password) - if user: - login(request, user) - return True - return False - - def render_login_page(self, request: Request) -> Any: - return render(request, UserLoginTemplate.TEMPLATE) - - def render_login_page_with_error(self, request: Request, error_message: str) -> Any: - return render( - request, - UserLoginTemplate.TEMPLATE, - {UserLoginTemplate.ERROR_PLACE_HOLDER: error_message}, - ) - - def validate_login_credentials(self, request: Request) -> Any: - """Validate the login credentials. - - Args: - request (Request): The HTTP request object. - - Returns: - dict: The validated login credentials. - - Raises: - ValueError: If the login credentials are invalid. - """ - serializer = LoginRequestSerializer(data=request.POST) - if not serializer.is_valid(): - error_messages = { - field: errors[0] for field, errors in serializer.errors.items() - } - first_error_message = list(error_messages.values())[0] - raise ValueError(first_error_message) - return serializer.validated_data - - def user_signup(self, request: HttpRequest) -> Any: - raise MethodNotImplemented() - - def is_admin_by_role(self, role: str) -> bool: - """Check the role with actual admin Role. - - Args: - role (str): input string - - Returns: - bool: _description_ - """ - try: - return UserRole(role.lower()) == UserRole.ADMIN - except ValueError: - return False - - def get_callback_data(self, request: Request) -> CallbackData: - return CallbackData( - user_id=request.user.user_id, - email=request.user.email, - token="", - ) - - def user_organization(self) -> Organization: - return Organization( - name=DefaultOrg.ORGANIZATION_NAME, - display_name=DefaultOrg.ORGANIZATION_NAME, - organization_id=DefaultOrg.ORGANIZATION_NAME, - schema_name=DefaultOrg.ORGANIZATION_NAME, - ) - - def handle_invited_user_while_callback( - self, request: Request, user: User - ) -> MemberData: - member_data: MemberData = MemberData( - user_id=user.user_id, - organization_id=self.default_organization.organization_id, - role=[UserRole.ADMIN.value], - ) - - return member_data - - def handle_authorization_callback(self, request: Request, backend: str) -> Response: - raise MethodNotImplemented() - - def add_to_organization( - self, - request: Request, - user: User, - data: Optional[dict[str, Any]] = None, - ) -> MemberData: - member_data: MemberData = MemberData( - user_id=user.user_id, - organization_id=self.default_organization.organization_id, - ) - - return member_data - - def remove_users_from_organization( - self, - admin: OrganizationMember, - organization_id: str, - user_ids: list[str], - ) -> bool: - raise MethodNotImplemented() - - def user_organizations(self, request: Request) -> list[OrganizationData]: - organizationData: OrganizationData = OrganizationData( - id=self.default_organization.organization_id, - display_name=self.default_organization.display_name, - name=self.default_organization.name, - ) - return [organizationData] - - def get_organizations_by_user_id(self, id: str) -> list[OrganizationData]: - organizationData: OrganizationData = OrganizationData( - id=self.default_organization.organization_id, - display_name=self.default_organization.display_name, - name=self.default_organization.name, - ) - return [organizationData] - - def get_organization_role_of_user( - self, user_id: str, organization_id: str - ) -> list[str]: - return [UserRole.ADMIN.value] - - def is_organization_admin(self, member: OrganizationMember) -> bool: - """Check if the organization member has administrative privileges. - - Args: - member (OrganizationMember): The organization member to check. - - Returns: - bool: True if the user has administrative privileges, - False otherwise. - """ - try: - return UserRole(member.role) == UserRole.ADMIN - except ValueError: - return False - - def check_user_organization_association(self, user_email: str) -> None: - """Check if the user is already associated with any organizations. - - Raises: - - UserAlreadyAssociatedException: - If the user is already associated with organizations. - """ - return None - - def get_roles(self) -> list[UserRoleData]: - return [ - UserRoleData(name=UserRole.ADMIN.value), - UserRoleData(name=UserRole.USER.value), - ] - - def get_invitations(self, organization_id: str) -> list[MemberInvitation]: - raise MethodNotImplemented() - - def frictionless_onboarding(self, organization: Organization, user: User) -> None: - raise MethodNotImplemented() - - def hubspot_signup_api(self, request: Request) -> None: - raise MethodNotImplemented() - - def delete_invitation(self, organization_id: str, invitation_id: str) -> bool: - raise MethodNotImplemented() - - def add_organization_user_role( - self, - admin: User, - organization_id: str, - user_id: str, - role_ids: list[str], - ) -> list[str]: - if admin.role == UserRole.ADMIN.value: - return role_ids - raise Forbidden - - def remove_organization_user_role( - self, - admin: User, - organization_id: str, - user_id: str, - role_ids: list[str], - ) -> list[str]: - if admin.role == UserRole.ADMIN.value: - return role_ids - raise Forbidden - - def get_organization_by_org_id(self, id: str) -> OrganizationData: - organizationData: OrganizationData = OrganizationData( - id=DefaultOrg.ORGANIZATION_NAME, - display_name=DefaultOrg.ORGANIZATION_NAME, - name=DefaultOrg.ORGANIZATION_NAME, - ) - return organizationData - - def set_default_user(self, username: str, password: str) -> bool: - """Set the default user for authentication. - - This method creates a default user with the provided username and - password if the username and password match the default values defined - in the 'DefaultOrg' class. The default user is saved in the database. - - Args: - username (str): The username of the default user. - password (str): The password of the default user. - - Returns: - bool: True if the default user is successfully created and saved, - False otherwise. - """ - if ( - username != DefaultOrg.MOCK_USER - or password != DefaultOrg.MOCK_USER_PASSWORD - ): - return False - - user, created = User.objects.get_or_create(username=DefaultOrg.MOCK_USER) - if created: - user.password = make_password(DefaultOrg.MOCK_USER_PASSWORD) - else: - user.user_id = DefaultOrg.MOCK_USER_ID - user.email = DefaultOrg.MOCK_USER_EMAIL - user.password = make_password(DefaultOrg.MOCK_USER_PASSWORD) - user.save() - return True - - def get_user_info(self, request: Request) -> Optional[UserInfo]: - user: User = request.user - if user: - return UserInfo( - id=user.id, - user_id=user.user_id, - name=user.username, - display_name=user.username, - email=user.email, - ) - else: - return None - - def get_organization_info(self, org_id: str) -> Optional[Organization]: - return OrganizationService.get_organization_by_org_id(org_id=org_id) - - def make_organization_and_add_member( - self, - user_id: str, - user_name: str, - organization_name: Optional[str] = None, - display_name: Optional[str] = None, - ) -> Optional[OrganizationData]: - organization: OrganizationData = OrganizationData( - id=str(uuid.uuid4()), - display_name=DefaultOrg.MOCK_ORG, - name=DefaultOrg.MOCK_ORG, - ) - return organization - - def make_user_organization_name(self) -> str: - return str(uuid.uuid4()) - - def make_user_organization_display_name(self, user_name: str) -> str: - name = f"{user_name}'s" if user_name else "Your" - return f"{name} organization" - - def user_logout(self, request: HttpRequest) -> Response: - """Log out the user. - - Args: - request (HttpRequest): The HTTP request object. - - Returns: - Response: The redirect response to the web app origin URL. - """ - logout(request) - return redirect(settings.WEB_APP_ORIGIN_URL) - - def get_organization_members_by_org_id( - self, organization_id: str - ) -> list[MemberData]: - users: list[OrganizationMember] = OrganizationMember.objects.all() - return self.authentication_helper.list_of_members_from_user_model(users) - - def reset_user_password(self, user: User) -> ResetUserPasswordDto: - raise MethodNotImplemented() - - def invite_user( - self, - admin: OrganizationMember, - org_id: str, - email: str, - role: Optional[str] = None, - ) -> bool: - raise MethodNotImplemented() diff --git a/backend/account/constants.py b/backend/account/constants.py deleted file mode 100644 index 15c113790..000000000 --- a/backend/account/constants.py +++ /dev/null @@ -1,89 +0,0 @@ -from django.conf import settings - - -class LoginConstant: - INVITATION = "invitation" - ORGANIZATION = "organization" - ORGANIZATION_NAME = "organization_name" - - -class Common: - NEXT_URL_VARIABLE = "next" - PUBLIC_SCHEMA_NAME = "public" - ID = "id" - USER_ID = "user_id" - USER_EMAIL = "email" - USER_EMAILS = "emails" - USER_IDS = "user_ids" - USER_ROLE = "role" - MAX_EMAIL_IN_REQUEST = 10 - LOG_EVENTS_ID = "log_events_id" - - -class UserModel: - USER_ID = "user_id" - ID = "id" - - -class OrganizationMemberModel: - USER_ID = "user__user_id" - ID = "user__id" - - -class Cookie: - ORG_ID = "org_id" - Z_CODE = "z_code" - CSRFTOKEN = "csrftoken" - - -class ErrorMessage: - ORGANIZATION_EXIST = "Organization already exists" - DUPLICATE_API = "It appears that a duplicate call may have been made." - USER_LOGIN_ERROR = "Invalid username or password. Please try again." - - -class DefaultOrg: - ORGANIZATION_NAME = "mock_org" - MOCK_ORG = "mock_org" - MOCK_USER = settings.DEFAULT_AUTH_USERNAME - MOCK_USER_ID = "mock_user_id" - MOCK_USER_EMAIL = "email@mock.com" - MOCK_USER_PASSWORD = settings.DEFAULT_AUTH_PASSWORD - - -class UserLoginTemplate: - TEMPLATE = "login.html" - ERROR_PLACE_HOLDER = "error_message" - - -class PluginConfig: - PLUGINS_APP = "plugins" - AUTH_MODULE_PREFIX = "auth" - AUTH_PLUGIN_DIR = "authentication" - AUTH_MODULE = "module" - AUTH_METADATA = "metadata" - METADATA_SERVICE_CLASS = "service_class" - METADATA_IS_ACTIVE = "is_active" - - -class AuthorizationErrorCode: - """Error codes - IDM: INVITATION DENIED MESSAGE (Unauthorized invitation) - INF: INVITATION NOT FOUND (Invitation is either invalid or has expired) - UMM: USER MEMBERSHIP MISCONDUCT - USF: USER FOUND (User Account Already Exists for Organization) - INE001: INVALID EMAIL Exception code when an invalid email address is used - like disposable. - INE002: INVALID EMAIL Exception code when an invalid email address format. - - Error code reference : - frontend/src/components/error/GenericError/GenericError.jsx. - """ - - IDM = "IDM" - UMM = "UMM" - INF = "INF" - USF = "USF" - USR = "USR" - INE001 = "INE001" - INE002 = "INE002" diff --git a/backend/account/custom_auth_middleware.py b/backend/account/custom_auth_middleware.py deleted file mode 100644 index 020f5f413..000000000 --- a/backend/account/custom_auth_middleware.py +++ /dev/null @@ -1,46 +0,0 @@ -from account.authentication_plugin_registry import AuthenticationPluginRegistry -from account.authentication_service import AuthenticationService -from account.constants import Common -from django.conf import settings -from django.http import HttpRequest, HttpResponse, JsonResponse -from utils.local_context import StateStore -from utils.user_session import UserSessionUtils - -from backend.constants import RequestHeader - - -class CustomAuthMiddleware: - def __init__(self, get_response: HttpResponse): - self.get_response = get_response - # One-time configuration and initialization. - - def __call__(self, request: HttpRequest) -> HttpResponse: - # Returns result without authenticated if added in whitelisted paths - if any(request.path.startswith(path) for path in settings.WHITELISTED_PATHS): - return self.get_response(request) - - # Authenticating With API_KEY - x_api_key = request.headers.get(RequestHeader.X_API_KEY) - if ( - settings.INTERNAL_SERVICE_API_KEY - and x_api_key == settings.INTERNAL_SERVICE_API_KEY - ): # Should API Key be in settings or just env alone? - return self.get_response(request) - - if AuthenticationPluginRegistry.is_plugin_available(): - auth_service: AuthenticationService = ( - AuthenticationPluginRegistry.get_plugin() - ) - else: - auth_service = AuthenticationService() - - is_authenticated = auth_service.is_authenticated(request) - is_authorized = UserSessionUtils.is_authorized_path(request) - - if is_authenticated and is_authorized: - StateStore.set(Common.LOG_EVENTS_ID, request.session.session_key) - response = self.get_response(request) - StateStore.clear(Common.LOG_EVENTS_ID) - - return response - return JsonResponse({"message": "Unauthorized"}, status=401) diff --git a/backend/account/custom_authentication.py b/backend/account/custom_authentication.py deleted file mode 100644 index 1f7cdcb6e..000000000 --- a/backend/account/custom_authentication.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any - -from django.http import HttpRequest -from rest_framework.exceptions import AuthenticationFailed - - -def api_login_required(view_func: Any) -> Any: - def wrapper(request: HttpRequest, *args: Any, **kwargs: Any) -> Any: - if request.user and request.session and "user" in request.session: - return view_func(request, *args, **kwargs) - raise AuthenticationFailed("Unauthorized") - - return wrapper diff --git a/backend/account/custom_cache.py b/backend/account/custom_cache.py deleted file mode 100644 index 182f980f5..000000000 --- a/backend/account/custom_cache.py +++ /dev/null @@ -1,12 +0,0 @@ -from django_redis import get_redis_connection - - -class CustomCache: - def __init__(self) -> None: - self.cache = get_redis_connection("default") - - def rpush(self, key: str, value: str) -> None: - self.cache.rpush(key, value) - - def lrem(self, key: str, value: str) -> None: - self.cache.lrem(key, value) diff --git a/backend/account/custom_exceptions.py b/backend/account/custom_exceptions.py deleted file mode 100644 index bec24e16c..000000000 --- a/backend/account/custom_exceptions.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class ConflictError(Exception): - def __init__(self, message: str) -> None: - self.message = message - super().__init__(self.message) - - -class MethodNotImplemented(APIException): - status_code = 501 - default_detail = "Method Not Implemented" - - -class DuplicateData(APIException): - status_code = 400 - default_detail = "Duplicate Data" - - def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): - if detail is not None: - self.detail = detail - if code is not None: - self.code = code - super().__init__(detail, code) - - -class TableNotExistError(APIException): - status_code = 400 - default_detail = "Unknown Table" - - def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): - if detail is not None: - self.detail = detail - if code is not None: - self.code = code - super().__init__() - - -class UserNotExistError(APIException): - status_code = 400 - default_detail = "Unknown User" - - def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): - if detail is not None: - self.detail = detail - if code is not None: - self.code = code - super().__init__() - - -class Forbidden(APIException): - status_code = 403 - default_detail = "Do not have permission to perform this action." - - -class UserAlreadyAssociatedException(APIException): - status_code = 400 - default_detail = "User is already associated with one organization." diff --git a/backend/account/dto.py b/backend/account/dto.py deleted file mode 100644 index 66cf6c1aa..000000000 --- a/backend/account/dto.py +++ /dev/null @@ -1,134 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Optional - - -@dataclass -class MemberData: - user_id: str - email: Optional[str] = None - name: Optional[str] = None - picture: Optional[str] = None - role: Optional[list[str]] = None - organization_id: Optional[str] = None - - -@dataclass -class OrganizationData: - id: str - display_name: str - name: str - - -@dataclass -class CallbackData: - user_id: str - email: str - token: Any - - -@dataclass -class OrganizationSignupRequestBody: - name: str - display_name: str - organization_id: str - - -@dataclass -class OrganizationSignupResponse: - name: str - display_name: str - organization_id: str - created_at: str - - -@dataclass -class UserInfo: - email: str - user_id: str - id: Optional[str] = None - name: Optional[str] = None - display_name: Optional[str] = None - family_name: Optional[str] = None - picture: Optional[str] = None - - -@dataclass -class UserSessionInfo: - id: str - user_id: str - email: str - organization_id: str - user: UserInfo - role: str - - @staticmethod - def from_dict(data: dict[str, Any]) -> "UserSessionInfo": - return UserSessionInfo( - id=data["id"], - user_id=data["user_id"], - email=data["email"], - organization_id=data["organization_id"], - role=data["role"], - ) - - def to_dict(self) -> Any: - return { - "id": self.id, - "user_id": self.user_id, - "email": self.email, - "organization_id": self.organization_id, - "role": self.role, - } - - -@dataclass -class GetUserReposne: - user: UserInfo - organizations: list[OrganizationData] - - -@dataclass -class ResetUserPasswordDto: - status: bool - message: str - - -@dataclass -class UserInviteResponse: - email: str - status: str - message: Optional[str] = None - - -@dataclass -class UserRoleData: - name: str - id: Optional[str] = None - description: Optional[str] = None - - -@dataclass -class MemberInvitation: - """Represents an invitation to join an organization. - - Attributes: - id (str): The unique identifier for the invitation. - email (str): The user email. - roles (List[str]): The roles assigned to the invitee. - created_at (Optional[str]): The timestamp when the invitation - was created. - expires_at (Optional[str]): The timestamp when the invitation expires. - """ - - id: str - email: str - roles: list[str] - created_at: Optional[str] = None - expires_at: Optional[str] = None - - -@dataclass -class UserOrganizationRole: - user_id: str - role: UserRoleData - organization_id: str diff --git a/backend/account/enums.py b/backend/account/enums.py deleted file mode 100644 index d8209ec2d..000000000 --- a/backend/account/enums.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class UserRole(Enum): - USER = "user" - ADMIN = "admin" diff --git a/backend/account/exceptions.py b/backend/account/exceptions.py deleted file mode 100644 index 9f0b443b6..000000000 --- a/backend/account/exceptions.py +++ /dev/null @@ -1,26 +0,0 @@ -from rest_framework.exceptions import APIException - - -class UserIdNotExist(APIException): - status_code = 404 - default_detail = "User ID does not exist" - - -class UserAlreadyExistInOrganization(APIException): - status_code = 403 - default_detail = "User allready exist in the organization" - - -class OrganizationNotExist(APIException): - status_code = 404 - default_detail = "Organization does not exist" - - -class UnknownException(APIException): - status_code = 500 - default_detail = "An unexpected error occurred" - - -class BadRequestException(APIException): - status_code = 400 - default_detail = "Bad Request" diff --git a/backend/account/migrations/0001_initial.py b/backend/account/migrations/0001_initial.py deleted file mode 100644 index e8622577b..000000000 --- a/backend/account/migrations/0001_initial.py +++ /dev/null @@ -1,244 +0,0 @@ -# Generated by Django 4.2.1 on 2023-07-18 10:39 - -import django.contrib.auth.models -import django.contrib.auth.validators -import django.db.models.deletion -import django.utils.timezone -import django_tenants.postgresql_backend.base -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - ("auth", "0012_alter_user_first_name_max_length"), - ] - - operations = [ - migrations.CreateModel( - name="User", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "password", - models.CharField(max_length=128, verbose_name="password"), - ), - ( - "last_login", - models.DateTimeField( - blank=True, null=True, verbose_name="last login" - ), - ), - ( - "is_superuser", - models.BooleanField( - default=False, - help_text="Designates that this user has all permissions without explicitly assigning them.", - verbose_name="superuser status", - ), - ), - ( - "username", - models.CharField( - error_messages={ - "unique": "A user with that username already exists." - }, - help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", - max_length=150, - unique=True, - validators=[ - django.contrib.auth.validators.UnicodeUsernameValidator() - ], - verbose_name="username", - ), - ), - ( - "first_name", - models.CharField( - blank=True, max_length=150, verbose_name="first name" - ), - ), - ( - "last_name", - models.CharField( - blank=True, max_length=150, verbose_name="last name" - ), - ), - ( - "email", - models.EmailField( - blank=True, max_length=254, verbose_name="email address" - ), - ), - ( - "is_staff", - models.BooleanField( - default=False, - help_text="Designates whether the user can log into this admin site.", - verbose_name="staff status", - ), - ), - ( - "is_active", - models.BooleanField( - default=True, - help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.", - verbose_name="active", - ), - ), - ( - "date_joined", - models.DateTimeField( - default=django.utils.timezone.now, - verbose_name="date joined", - ), - ), - ("user_id", models.CharField()), - ("project_storage_created", models.BooleanField(default=False)), - ("modified_at", models.DateTimeField(auto_now=True)), - ("created_at", models.DateTimeField(auto_now_add=True)), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_users", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "groups", - models.ManyToManyField( - blank=True, - related_name="customuser_set", - related_query_name="customuser", - to="auth.group", - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_users", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "user_permissions", - models.ManyToManyField( - blank=True, - related_name="customuser_set", - related_query_name="customuser", - to="auth.permission", - ), - ), - ], - options={ - "verbose_name": "user", - "verbose_name_plural": "users", - "abstract": False, - }, - managers=[ - ("objects", django.contrib.auth.models.UserManager()), - ], - ), - migrations.CreateModel( - name="Organization", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "schema_name", - models.CharField( - db_index=True, - max_length=63, - unique=True, - validators=[ - django_tenants.postgresql_backend.base._check_schema_name - ], - ), - ), - ("name", models.CharField(max_length=64)), - ("display_name", models.CharField(max_length=64)), - ("organization_id", models.CharField(max_length=64)), - ("modified_at", models.DateTimeField(auto_now=True)), - ("created_at", models.DateTimeField(auto_now=True)), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_orgs", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_orgs", - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "abstract": False, - }, - ), - migrations.CreateModel( - name="Domain", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "domain", - models.CharField(db_index=True, max_length=253, unique=True), - ), - ( - "is_primary", - models.BooleanField(db_index=True, default=True), - ), - ( - "tenant", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="domains", - to="account.organization", - ), - ), - ], - options={ - "abstract": False, - }, - ), - ] diff --git a/backend/account/migrations/0002_auto_20230718_1040.py b/backend/account/migrations/0002_auto_20230718_1040.py deleted file mode 100644 index 4a6059157..000000000 --- a/backend/account/migrations/0002_auto_20230718_1040.py +++ /dev/null @@ -1,40 +0,0 @@ -# mypy: ignore-errors -# Generated by Django 4.2.1 on 2023-07-18 10:40 - -from django.conf import settings -from django.contrib.auth.hashers import make_password -from django.db import migrations - - -def create_public_tenant_and_domain(apps, schema_editor): - organization_model = apps.get_model("account", "Organization") - - # public tenant - tenant = organization_model( - name="public", - display_name="public", - organization_id="public", - schema_name="public", - ) - tenant.save() - - user_model = apps.get_model("account", "User") - # public User admin - user = user_model( - username=settings.SYSTEM_ADMIN_USERNAME, - email=settings.SYSTEM_ADMIN_EMAIL, - is_superuser=True, - is_staff=True, - password=make_password(settings.SYSTEM_ADMIN_PASSWORD), - ) - user.save() - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0001_initial"), - ] - - operations = [ - migrations.RunPython(create_public_tenant_and_domain), - ] diff --git a/backend/account/migrations/0003_platformkey.py b/backend/account/migrations/0003_platformkey.py deleted file mode 100644 index f01cc5295..000000000 --- a/backend/account/migrations/0003_platformkey.py +++ /dev/null @@ -1,66 +0,0 @@ -# Generated by Django 4.2.1 on 2023-11-02 05:22 - -import uuid - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0002_auto_20230718_1040"), - ] - - operations = [ - migrations.CreateModel( - name="PlatformKey", - fields=[ - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ("key", models.UUIDField(default=uuid.uuid4)), - ( - "key_name", - models.CharField(blank=True, max_length=64, null=True, unique=True), - ), - ("is_active", models.BooleanField(default=False)), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_keys", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_keys", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "organization", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="related_org", - to="account.organization", - ), - ), - ], - ), - ] diff --git a/backend/account/migrations/0004_alter_platformkey_key_name_and_more.py b/backend/account/migrations/0004_alter_platformkey_key_name_and_more.py deleted file mode 100644 index 86df7e29b..000000000 --- a/backend/account/migrations/0004_alter_platformkey_key_name_and_more.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 4.2.1 on 2023-11-15 11:37 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0003_platformkey"), - ] - - operations = [ - migrations.AlterField( - model_name="platformkey", - name="key_name", - field=models.CharField(blank=True, default="", max_length=64), - ), - migrations.AddConstraint( - model_name="platformkey", - constraint=models.UniqueConstraint( - fields=("key_name", "organization"), name="unique_key_name" - ), - ), - ] diff --git a/backend/account/migrations/0005_encryptionsecret.py b/backend/account/migrations/0005_encryptionsecret.py deleted file mode 100644 index 724c12522..000000000 --- a/backend/account/migrations/0005_encryptionsecret.py +++ /dev/null @@ -1,27 +0,0 @@ -# Generated by Django 4.2.1 on 2024-02-13 11:52 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0004_alter_platformkey_key_name_and_more"), - ] - - operations = [ - migrations.CreateModel( - name="EncryptionSecret", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("key", models.CharField(blank=True, max_length=64)), - ], - ), - ] diff --git a/backend/account/migrations/0006_delete_encryptionsecret.py b/backend/account/migrations/0006_delete_encryptionsecret.py deleted file mode 100644 index 1216373ee..000000000 --- a/backend/account/migrations/0006_delete_encryptionsecret.py +++ /dev/null @@ -1,15 +0,0 @@ -# Generated by Django 4.2.1 on 2024-03-04 05:06 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0005_encryptionsecret"), - ] - - operations = [ - migrations.DeleteModel( - name="EncryptionSecret", - ), - ] diff --git a/backend/account/migrations/0007_organization_allowed_token_limit.py b/backend/account/migrations/0007_organization_allowed_token_limit.py deleted file mode 100644 index af6baa8a1..000000000 --- a/backend/account/migrations/0007_organization_allowed_token_limit.py +++ /dev/null @@ -1,20 +0,0 @@ -# Generated by Django 4.2.1 on 2024-04-25 07:55 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0006_delete_encryptionsecret"), - ] - - operations = [ - migrations.AddField( - model_name="organization", - name="allowed_token_limit", - field=models.IntegerField( - db_comment="token limit set in case of frition less onbaoarded org", - default=-1, - ), - ), - ] diff --git a/backend/account/migrations/__init__.py b/backend/account/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/account/models.py b/backend/account/models.py deleted file mode 100644 index 063a7f7cd..000000000 --- a/backend/account/models.py +++ /dev/null @@ -1,136 +0,0 @@ -import uuid - -from django.contrib.auth.models import AbstractUser, Group, Permission -from django.db import models -from django_tenants.models import DomainMixin, TenantMixin - -from backend.constants import FieldLengthConstants as FieldLength - -NAME_SIZE = 64 -KEY_SIZE = 64 - - -class Organization(TenantMixin): - """Stores data related to an organization. - - The fields created_by and modified_by is updated after a - :model:`account.User` is created. - """ - - name = models.CharField(max_length=NAME_SIZE) - display_name = models.CharField(max_length=NAME_SIZE) - organization_id = models.CharField(max_length=FieldLength.ORG_NAME_SIZE) - created_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="created_orgs", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="modified_orgs", - null=True, - blank=True, - ) - modified_at = models.DateTimeField(auto_now=True) - created_at = models.DateTimeField(auto_now=True) - allowed_token_limit = models.IntegerField( - default=-1, - db_comment="token limit set in case of frition less onbaoarded org", - ) - - auto_create_schema = True - - -class Domain(DomainMixin): - pass - - -class User(AbstractUser): - """Stores data related to a user belonging to any organization. - - Every org, user is assumed to be unique. - """ - - # Third Party Authentication User ID - user_id = models.CharField() - project_storage_created = models.BooleanField(default=False) - created_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="created_users", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="modified_users", - null=True, - blank=True, - ) - modified_at = models.DateTimeField(auto_now=True) - created_at = models.DateTimeField(auto_now_add=True) - - # Specify a unique related_name for the groups field - groups = models.ManyToManyField( - Group, - related_name="customuser_set", - related_query_name="customuser", - blank=True, - ) - - # Specify a unique related_name for the user_permissions field - user_permissions = models.ManyToManyField( - Permission, - related_name="customuser_set", - related_query_name="customuser", - blank=True, - ) - - def __str__(self): # type: ignore - return f"User({self.id}, email: {self.email}, userId: {self.user_id})" - - -class PlatformKey(models.Model): - """Model to hold details of Platform keys. - - Only users with admin role are allowed to perform any operation - related keys. - """ - - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - key = models.UUIDField(default=uuid.uuid4) - key_name = models.CharField(max_length=KEY_SIZE, null=False, blank=True, default="") - is_active = models.BooleanField(default=False) - organization = models.ForeignKey( - "Organization", - on_delete=models.SET_NULL, - related_name="related_org", - null=True, - blank=True, - ) - created_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="created_keys", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - "User", - on_delete=models.SET_NULL, - related_name="modified_keys", - null=True, - blank=True, - ) - - class Meta: - constraints = [ - models.UniqueConstraint( - fields=["key_name", "organization"], - name="unique_key_name", - ), - ] diff --git a/backend/account/organization.py b/backend/account/organization.py deleted file mode 100644 index 5cd547b3f..000000000 --- a/backend/account/organization.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -from typing import Optional - -from account.models import Domain, Organization -from account.subscription_loader import SubscriptionConfig, load_plugins -from django.db import IntegrityError - -Logger = logging.getLogger(__name__) - -subscription_loader = load_plugins() - - -class OrganizationService: - def __init__(self): # type: ignore - pass - - @staticmethod - def get_organization_by_org_id(org_id: str) -> Optional[Organization]: - try: - return Organization.objects.get(organization_id=org_id) # type: ignore - except Organization.DoesNotExist: - return None - - @staticmethod - def create_organization( - name: str, display_name: str, organization_id: str - ) -> Organization: - try: - organization: Organization = Organization( - name=name, - display_name=display_name, - organization_id=organization_id, - schema_name=organization_id, - ) - organization.save() - - for subscription_plugin in subscription_loader: - cls = subscription_plugin[SubscriptionConfig.METADATA][ - SubscriptionConfig.METADATA_SERVICE_CLASS - ] - cls.add(organization_id=organization_id) - - except IntegrityError as error: - Logger.info(f"[Duplicate Id] Failed to create Organization Error: {error}") - raise error - # Add one or more domains for the tenant - domain = Domain() - domain.domain = organization_id - domain.tenant = organization - domain.is_primary = True - domain.save() - return organization diff --git a/backend/account/serializer.py b/backend/account/serializer.py deleted file mode 100644 index c40c0b993..000000000 --- a/backend/account/serializer.py +++ /dev/null @@ -1,119 +0,0 @@ -import re -from typing import Optional - -# from account.enums import Region -from account.models import Organization, User -from rest_framework import serializers - - -class OrganizationSignupSerializer(serializers.Serializer): - name = serializers.CharField(required=True, max_length=150) - display_name = serializers.CharField(required=True, max_length=150) - organization_id = serializers.CharField(required=True, max_length=30) - - def validate_organization_id(self, value): # type: ignore - if not re.match(r"^[a-z0-9_-]+$", value): - raise serializers.ValidationError( - "organization_code should only contain " - "alphanumeric characters,_ and -." - ) - return value - - -class OrganizationCallbackSerializer(serializers.Serializer): - id = serializers.CharField(required=False) - - -class GetOrganizationsResponseSerializer(serializers.Serializer): - id = serializers.CharField() - display_name = serializers.CharField() - name = serializers.CharField() - # Add more fields as needed - - def to_representation(self, instance): # type: ignore - data = super().to_representation(instance) - # Modify the representation if needed - return data - - -class GetOrganizationMembersResponseSerializer(serializers.Serializer): - user_id = serializers.CharField() - email = serializers.CharField() - name = serializers.CharField() - picture = serializers.CharField() - # Add more fields as needed - - def to_representation(self, instance): # type: ignore - data = super().to_representation(instance) - # Modify the representation if needed - return data - - -class OrganizationSerializer(serializers.Serializer): - name = serializers.CharField() - organization_id = serializers.CharField() - - -class SetOrganizationsResponseSerializer(serializers.Serializer): - id = serializers.CharField() - email = serializers.CharField() - name = serializers.CharField() - display_name = serializers.CharField() - family_name = serializers.CharField() - picture = serializers.CharField() - # Add more fields as needed - - def to_representation(self, instance): # type: ignore - data = super().to_representation(instance) - # Modify the representation if needed - return data - - -class ModelTenantSerializer(serializers.ModelSerializer): - class Meta: - model = Organization - fields = fields = ("name", "created_on") - - -class UserSerializer(serializers.ModelSerializer): - class Meta: - model = User - fields = ("id", "username") - - -class OrganizationSignupResponseSerializer(serializers.Serializer): - name = serializers.CharField() - display_name = serializers.CharField() - organization_id = serializers.CharField() - created_at = serializers.CharField() - - -class LoginRequestSerializer(serializers.Serializer): - username = serializers.CharField(required=True) - password = serializers.CharField(required=True) - - def validate_username(self, value: Optional[str]) -> str: - """Check that the username is not empty and has at least 3 - characters.""" - if not value or len(value) < 3: - raise serializers.ValidationError( - "Username must be at least 3 characters long." - ) - return value - - def validate_password(self, value: Optional[str]) -> str: - """Check that the password is not empty and has at least 3 - characters.""" - if not value or len(value) < 3: - raise serializers.ValidationError( - "Password must be at least 3 characters long." - ) - return value - - -class UserSessionResponseSerializer(serializers.Serializer): - id = serializers.IntegerField() - user_id = serializers.CharField() - email = serializers.CharField() - organization_id = serializers.CharField() - role = serializers.CharField() diff --git a/backend/account/subscription_loader.py b/backend/account/subscription_loader.py deleted file mode 100644 index 133a4d9eb..000000000 --- a/backend/account/subscription_loader.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging -import os -from importlib import import_module -from typing import Any - -from django.apps import apps -from django.utils import timezone - -logger = logging.getLogger(__name__) - - -class SubscriptionConfig: - """Loader config for subscription plugins.""" - - PLUGINS_APP = "plugins" - PLUGIN_DIR = "subscription" - MODULE = "module" - METADATA = "metadata" - METADATA_NAME = "name" - METADATA_SERVICE_CLASS = "service_class" - METADATA_IS_ACTIVE = "is_active" - - -def load_plugins() -> list[Any]: - """Iterate through the subscription plugins and register them.""" - plugins_app = apps.get_app_config(SubscriptionConfig.PLUGINS_APP) - package_path = plugins_app.module.__package__ - subscription_dir = os.path.join(plugins_app.path, SubscriptionConfig.PLUGIN_DIR) - subscription_package_path = f"{package_path}.{SubscriptionConfig.PLUGIN_DIR}" - subscription_plugins: list[Any] = [] - - if not os.path.exists(subscription_dir): - return subscription_plugins - - for item in os.listdir(subscription_dir): - # Loads a plugin if it is in a directory. - if os.path.isdir(os.path.join(subscription_dir, item)): - subscription_module_name = item - # Loads a plugin if it is a shared library. - # Module name is extracted from shared library name. - # `subscription.platform_architecture.so` will be file name and - # `subscription` will be the module name. - elif item.endswith(".so"): - subscription_module_name = item.split(".")[0] - else: - continue - try: - full_module_path = f"{subscription_package_path}.{subscription_module_name}" - module = import_module(full_module_path) - metadata = getattr(module, SubscriptionConfig.METADATA, {}) - - if metadata.get(SubscriptionConfig.METADATA_IS_ACTIVE, False): - subscription_plugins.append( - { - SubscriptionConfig.MODULE: module, - SubscriptionConfig.METADATA: module.metadata, - } - ) - logger.info( - "Loaded subscription plugin: %s, is_active: %s", - module.metadata[SubscriptionConfig.METADATA_NAME], - module.metadata[SubscriptionConfig.METADATA_IS_ACTIVE], - ) - else: - logger.info( - "subscription plugin %s is not active.", - subscription_module_name, - ) - except ModuleNotFoundError as exception: - logger.error( - "Error while importing subscription plugin: %s", - exception, - ) - - if len(subscription_plugins) == 0: - logger.info("No subscription plugins found.") - - return subscription_plugins - - -def validate_etl_run(org_id: str) -> bool: - """Method to check subscription status before ETL runs. - - Args: - org_id: The ID of the organization. - - Returns: - A boolean indicating whether the pre-run check passed or not. - """ - try: - from pluggable_apps.subscription.subscription_helper import SubscriptionHelper - except ModuleNotFoundError: - logger.error("Subscription plugin not found.") - return False - - org_plans = SubscriptionHelper.get_subscription(org_id) - if not org_plans or not org_plans.is_active: - return False - - if org_plans.is_paid: - return True - - if timezone.now() >= org_plans.end_date: - logger.debug(f"Trial expired for org {org_id}") - return False - - return True diff --git a/backend/account/templates/index.html b/backend/account/templates/index.html deleted file mode 100644 index ffa0b6085..000000000 --- a/backend/account/templates/index.html +++ /dev/null @@ -1,11 +0,0 @@ - - - - - ZipstackID Django App Example - - -

Welcome Guest

-

Login

- - diff --git a/backend/account/templates/login.html b/backend/account/templates/login.html deleted file mode 100644 index 4edf8e3f1..000000000 --- a/backend/account/templates/login.html +++ /dev/null @@ -1,134 +0,0 @@ - - - - - - Login - - - -
-
- -
-
- {% load static %} -
- My image -
- - {% if error_message %} -

{{ error_message }}

- {% endif %} - {% csrf_token %} - - - -

- -
- - - diff --git a/backend/account/tests.py b/backend/account/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/account/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/account/urls.py b/backend/account/urls.py deleted file mode 100644 index aeb10707f..000000000 --- a/backend/account/urls.py +++ /dev/null @@ -1,22 +0,0 @@ -from account.views import ( - callback, - create_organization, - get_organizations, - get_session_data, - login, - logout, - set_organization, - signup, -) -from django.urls import path - -urlpatterns = [ - path("login", login, name="login"), - path("signup", signup, name="signup"), - path("logout", logout, name="logout"), - path("callback", callback, name="callback"), - path("session", get_session_data, name="session"), - path("organization", get_organizations, name="get_organizations"), - path("organization//set", set_organization, name="set_organization"), - path("organization/create", create_organization, name="create_organization"), -] diff --git a/backend/account/user.py b/backend/account/user.py deleted file mode 100644 index c1499e69c..000000000 --- a/backend/account/user.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import Any, Optional - -from account.models import User -from django.db import IntegrityError - -Logger = logging.getLogger(__name__) - - -class UserService: - def __init__( - self, - ) -> None: - pass - - def create_user(self, email: str, user_id: str) -> User: - try: - user: User = User(email=email, user_id=user_id, username=email) - user.save() - except IntegrityError as error: - Logger.info(f"[Duplicate Id] Failed to create User Error: {error}") - raise error - return user - - def update_user(self, user: User, user_id: str) -> User: - user.user_id = user_id - user.save() - return user - - def get_user_by_email(self, email: str) -> Optional[User]: - try: - user: User = User.objects.get(email=email) - return user - except User.DoesNotExist: - return None - - def get_user_by_user_id(self, user_id: str) -> Any: - try: - return User.objects.get(user_id=user_id) - except User.DoesNotExist: - return None - - def get_user_by_id(self, id: str) -> Any: - """Retrieve a user by their ID, taking into account the schema context. - - Args: - id (str): The ID of the user. - - Returns: - Any: The user object if found, or None if not found. - """ - try: - return User.objects.get(id=id) - except User.DoesNotExist: - return None diff --git a/backend/account/views.py b/backend/account/views.py deleted file mode 100644 index 0e6582821..000000000 --- a/backend/account/views.py +++ /dev/null @@ -1,170 +0,0 @@ -import logging -from typing import Any - -from account.authentication_controller import AuthenticationController -from account.dto import ( - OrganizationSignupRequestBody, - OrganizationSignupResponse, - UserSessionInfo, -) -from account.models import Organization -from account.organization import OrganizationService -from account.serializer import ( - OrganizationSignupResponseSerializer, - OrganizationSignupSerializer, - UserSessionResponseSerializer, -) -from rest_framework import status -from rest_framework.decorators import api_view -from rest_framework.request import Request -from rest_framework.response import Response -from utils.user_session import UserSessionUtils - -Logger = logging.getLogger(__name__) - - -@api_view(["POST"]) -def create_organization(request: Request) -> Response: - serializer = OrganizationSignupSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - try: - requestBody: OrganizationSignupRequestBody = makeSignupRequestParams(serializer) - - organization: Organization = OrganizationService.create_organization( - requestBody.name, - requestBody.display_name, - requestBody.organization_id, - ) - response = makeSignupResponse(organization) - return Response( - status=status.HTTP_201_CREATED, - data={"message": "success", "tenant": response}, - ) - except Exception as error: - Logger.error(error) - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Unknown Error" - ) - - -@api_view(["GET"]) -def callback(request: Request) -> Response: - auth_controller = AuthenticationController() - return auth_controller.authorization_callback(request) - - -@api_view(["GET", "POST"]) -def login(request: Request) -> Response: - auth_controller = AuthenticationController() - return auth_controller.user_login(request) - - -@api_view(["GET"]) -def signup(request: Request) -> Response: - auth_controller = AuthenticationController() - return auth_controller.user_signup(request) - - -@api_view(["GET"]) -def logout(request: Request) -> Response: - auth_controller = AuthenticationController() - return auth_controller.user_logout(request) - - -@api_view(["GET"]) -def get_organizations(request: Request) -> Response: - """get_organizations. - - Retrieve the list of organizations to which the user belongs. - Args: - request (HttpRequest): _description_ - - Returns: - Response: A list of organizations with associated information. - """ - auth_controller = AuthenticationController() - return auth_controller.user_organizations(request) - - -@api_view(["POST"]) -def set_organization(request: Request, id: str) -> Response: - """set_organization. - - Set the current organization to use. - Args: - request (HttpRequest): _description_ - id (String): organization Id - - Returns: - Response: Contains the User and Current organization details. - """ - - auth_controller = AuthenticationController() - return auth_controller.set_user_organization(request, id) - - -@api_view(["GET"]) -def get_session_data(request: Request) -> Response: - """get_session_data. - - Retrieve the current session data. - Args: - request (HttpRequest): _description_ - - Returns: - Response: Contains the User and Current organization details. - """ - response = make_session_response(request) - - return Response( - status=status.HTTP_201_CREATED, - data=response, - ) - - -def make_session_response( - request: Request, -) -> Any: - """make_session_response. - - Make the current session data. - Args: - request (HttpRequest): _description_ - - Returns: - User and Current organization details. - """ - auth_controller = AuthenticationController() - return UserSessionResponseSerializer( - UserSessionInfo( - id=request.user.id, - user_id=request.user.user_id, - email=request.user.email, - user=auth_controller.get_user_info(request), - organization_id=UserSessionUtils.get_organization_id(request), - role=UserSessionUtils.get_organization_member_role(request), - ) - ).data - - -def makeSignupRequestParams( - serializer: OrganizationSignupSerializer, -) -> OrganizationSignupRequestBody: - return OrganizationSignupRequestBody( - serializer.validated_data["name"], - serializer.validated_data["display_name"], - serializer.validated_data["organization_id"], - ) - - -def makeSignupResponse( - organization: Organization, -) -> Any: - return OrganizationSignupResponseSerializer( - OrganizationSignupResponse( - organization.name, - organization.display_name, - organization.organization_id, - organization.created_at, - ) - ).data diff --git a/backend/adapter_processor/__init__.py b/backend/adapter_processor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/adapter_processor/adapter_processor.py b/backend/adapter_processor/adapter_processor.py deleted file mode 100644 index 4b3407157..000000000 --- a/backend/adapter_processor/adapter_processor.py +++ /dev/null @@ -1,250 +0,0 @@ -import json -import logging -from typing import Any, Optional - -from account.models import User -from adapter_processor.constants import AdapterKeys -from adapter_processor.exceptions import ( - InternalServiceError, - InValidAdapterId, - TestAdapterError, -) -from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist -from platform_settings.platform_auth_service import PlatformAuthenticationService -from unstract.sdk.adapters.adapterkit import Adapterkit -from unstract.sdk.adapters.base import Adapter -from unstract.sdk.adapters.enums import AdapterTypes -from unstract.sdk.adapters.x2text.constants import X2TextConstants -from unstract.sdk.exceptions import SdkError - -from .models import AdapterInstance, UserDefaultAdapter - -logger = logging.getLogger(__name__) - - -class AdapterProcessor: - @staticmethod - def get_json_schema(adapter_id: str) -> dict[str, Any]: - """Function to return JSON Schema for Adapters.""" - schema_details: dict[str, Any] = {} - updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( - AdapterKeys.ID, adapter_id - ) - if len(updated_adapters) != 0: - schema_details[AdapterKeys.JSON_SCHEMA] = json.loads( - updated_adapters[0].get(AdapterKeys.JSON_SCHEMA) - ) - else: - logger.error( - f"Invalid adapter Id : {adapter_id} while fetching JSON Schema" - ) - raise InValidAdapterId() - return schema_details - - @staticmethod - def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]: - """Function to return list of all supported adapters.""" - supported_adapters = [] - updated_adapters = [] - updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( - AdapterKeys.ADAPTER_TYPE, type - ) - for each_adapter in updated_adapters: - supported_adapters.append( - { - AdapterKeys.ID: each_adapter.get(AdapterKeys.ID), - AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME), - AdapterKeys.DESCRIPTION: each_adapter.get(AdapterKeys.DESCRIPTION), - AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON), - AdapterKeys.ADAPTER_TYPE: each_adapter.get( - AdapterKeys.ADAPTER_TYPE - ), - } - ) - return supported_adapters - - @staticmethod - def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any: - """Generic Function to get adapter data with provided key.""" - updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value( - "id", adapter_id - ) - if len(updated_adapters) == 0: - logger.error(f"Invalid adapter ID {adapter_id} while invoking utility") - raise InValidAdapterId() - return updated_adapters[0].get(key_value) - - @staticmethod - def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: - logger.info(f"Testing adapter: {adapter_id}") - try: - adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id) - - if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT: - adapter_metadata[X2TextConstants.X2TEXT_HOST] = settings.X2TEXT_HOST - adapter_metadata[X2TextConstants.X2TEXT_PORT] = settings.X2TEXT_PORT - platform_key = PlatformAuthenticationService.get_active_platform_key() - adapter_metadata[X2TextConstants.PLATFORM_SERVICE_API_KEY] = str( - platform_key.key - ) - - adapter_instance = adapter_class(adapter_metadata) - test_result: bool = adapter_instance.test_connection() - logger.info(f"{adapter_id} test result: {test_result}") - return test_result - except SdkError as e: - raise TestAdapterError(str(e)) - - @staticmethod - def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter: - """Fetches a list of adapters that have an attribute matching key and - value.""" - logger.info(f"Fetching adapter list for {key} with {value}") - adapter_kit = Adapterkit() - adapters = adapter_kit.get_adapters_list() - return [iterate for iterate in adapters if iterate[key] == value] - - @staticmethod - def set_default_triad(default_triad: dict[str, str], user: User) -> None: - try: - ( - user_default_adapter, - created, - ) = UserDefaultAdapter.objects.get_or_create(user=user) - - if default_triad.get(AdapterKeys.LLM_DEFAULT, None): - user_default_adapter.default_llm_adapter = AdapterInstance.objects.get( - pk=default_triad[AdapterKeys.LLM_DEFAULT] - ) - if default_triad.get(AdapterKeys.EMBEDDING_DEFAULT, None): - user_default_adapter.default_embedding_adapter = ( - AdapterInstance.objects.get( - pk=default_triad[AdapterKeys.EMBEDDING_DEFAULT] - ) - ) - - if default_triad.get(AdapterKeys.VECTOR_DB_DEFAULT, None): - user_default_adapter.default_vector_db_adapter = ( - AdapterInstance.objects.get( - pk=default_triad[AdapterKeys.VECTOR_DB_DEFAULT] - ) - ) - - if default_triad.get(AdapterKeys.X2TEXT_DEFAULT, None): - user_default_adapter.default_x2text_adapter = ( - AdapterInstance.objects.get( - pk=default_triad[AdapterKeys.X2TEXT_DEFAULT] - ) - ) - - user_default_adapter.save() - - logger.info("Changed defaults successfully") - except Exception as e: - logger.error(f"Unable to save defaults because: {e}") - if isinstance(e, InValidAdapterId): - raise e - else: - raise InternalServiceError() - - @staticmethod - def get_adapter_instance_by_id(adapter_instance_id: str) -> Adapter: - """Get the adapter instance by its ID. - - Parameters: - - adapter_instance_id (str): The ID of the adapter instance. - - Returns: - - Adapter: The adapter instance with the specified ID. - - Raises: - - Exception: If there is an error while fetching the adapter instance. - """ - try: - adapter = AdapterInstance.objects.get(id=adapter_instance_id) - except Exception as e: - logger.error(f"Unable to fetch adapter: {e}") - if not adapter: - logger.error("Unable to fetch adapter") - return adapter.adapter_name - - @staticmethod - def get_adapters_by_type( - adapter_type: AdapterTypes, user: User - ) -> list[AdapterInstance]: - """Get a list of adapters by their type. - - Parameters: - - adapter_type (AdapterTypes): The type of adapters to retrieve. - - user: Logged in User - - Returns: - - list[AdapterInstance]: A list of AdapterInstance objects that match - the specified adapter type. - """ - - adapters: list[AdapterInstance] = AdapterInstance.objects.for_user(user).filter( - adapter_type=adapter_type.value, - ) - return adapters - - @staticmethod - def get_adapter_by_name_and_type( - adapter_type: AdapterTypes, - adapter_name: Optional[str] = None, - ) -> Optional[AdapterInstance]: - """Get the adapter instance by its name and type. - - Parameters: - - adapter_name (str): The name of the adapter instance. - - adapter_type (AdapterTypes): The type of the adapter instance. - - Returns: - - AdapterInstance: The adapter with the specified name and type. - """ - if adapter_name: - adapter: AdapterInstance = AdapterInstance.objects.get( - adapter_name=adapter_name, adapter_type=adapter_type.value - ) - else: - try: - adapter = AdapterInstance.objects.get( - adapter_type=adapter_type.value, is_default=True - ) - except AdapterInstance.DoesNotExist: - return None - return adapter - - @staticmethod - def get_default_adapters(user: User) -> list[AdapterInstance]: - """Retrieve a list of default adapter instances. This method queries - the database to fetch all adapter instances marked as default. - - Raises: - InternalServiceError: If an unexpected error occurs during - the database query. - - Returns: - list[AdapterInstance]: A list of AdapterInstance objects that are - marked as default. - """ - try: - adapters: list[AdapterInstance] = [] - default_adapter = UserDefaultAdapter.objects.get(user=user) - - if default_adapter.default_embedding_adapter: - adapters.append(default_adapter.default_embedding_adapter) - if default_adapter.default_llm_adapter: - adapters.append(default_adapter.default_llm_adapter) - if default_adapter.default_vector_db_adapter: - adapters.append(default_adapter.default_vector_db_adapter) - if default_adapter.default_x2text_adapter: - adapters.append(default_adapter.default_x2text_adapter) - - return adapters - except ObjectDoesNotExist as e: - logger.error(f"No default adapters found: {e}") - raise InternalServiceError( - "No default adapters found, configure them through Platform Settings" - ) diff --git a/backend/adapter_processor/constants.py b/backend/adapter_processor/constants.py deleted file mode 100644 index 6557491b9..000000000 --- a/backend/adapter_processor/constants.py +++ /dev/null @@ -1,29 +0,0 @@ -class AdapterKeys: - JSON_SCHEMA = "json_schema" - ADAPTER_TYPE = "adapter_type" - IS_DEFAULT = "is_default" - LLM = "LLM" - X2TEXT = "X2TEXT" - OCR = "OCR" - VECTOR_DB = "VECTOR_DB" - EMBEDDING = "EMBEDDING" - NAME = "name" - DESCRIPTION = "description" - ICON = "icon" - ADAPTER_ID = "adapter_id" - ADAPTER_METADATA = "adapter_metadata" - ADAPTER_METADATA_B = "adapter_metadata_b" - ID = "id" - IS_VALID = "is_valid" - LLM_DEFAULT = "llm_default" - VECTOR_DB_DEFAULT = "vector_db_default" - EMBEDDING_DEFAULT = "embedding_default" - X2TEXT_DEFAULT = "x2text_default" - SHARED_USERS = "shared_users" - ADAPTER_NAME_EXISTS = ( - "Configuration with this name already exists within your organisation. " - "Please try with a different name." - ) - ADAPTER_NAME = "adapter_name" - ADAPTER_CREATED_BY = "created_by_email" - ADAPTER_CONTEXT_WINDOW_SIZE = "context_window_size" diff --git a/backend/adapter_processor/exceptions.py b/backend/adapter_processor/exceptions.py deleted file mode 100644 index 876775ca5..000000000 --- a/backend/adapter_processor/exceptions.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Optional - -from adapter_processor.constants import AdapterKeys -from rest_framework.exceptions import APIException - - -class IdIsMandatory(APIException): - status_code = 400 - default_detail = "ID is Mandatory." - - -class InValidType(APIException): - status_code = 400 - default_detail = "Type is not Valid." - - -class InValidAdapterId(APIException): - status_code = 400 - default_detail = "Adapter ID is not Valid." - - -class InternalServiceError(APIException): - status_code = 500 - default_detail = "Internal Service error" - - -class CannotDeleteDefaultAdapter(APIException): - status_code = 500 - default_detail = ( - "This is configured as default and cannot be deleted. " - "Please configure a different default before you try again!" - ) - - -class DuplicateAdapterNameError(APIException): - status_code = 400 - default_detail: str = AdapterKeys.ADAPTER_NAME_EXISTS - - def __init__( - self, - name: Optional[str] = None, - detail: Optional[str] = None, - code: Optional[str] = None, - ) -> None: - if name: - detail = self.default_detail.replace("this name", f"name '{name}'") - super().__init__(detail, code) - - -class TestAdapterError(APIException): - status_code = 500 - default_detail = "Error while testing adapter" - - -class TestAdapterInputError(APIException): - status_code = 400 - default_detail = "Error while testing adapter, please check the configuration." - - -class DeleteAdapterInUseError(APIException): - status_code = 409 - - def __init__( - self, - detail: Optional[str] = None, - code: Optional[str] = None, - adapter_name: str = "adapter", - ): - if detail is None: - if adapter_name != "adapter": - adapter_name = f"'{adapter_name}'" - detail = ( - f"Cannot delete {adapter_name}. " - "It is used in a workflow or a prompt studio project" - ) - super().__init__(detail, code) diff --git a/backend/adapter_processor/migrations/0001_initial.py b/backend/adapter_processor/migrations/0001_initial.py deleted file mode 100644 index 1b954e60b..000000000 --- a/backend/adapter_processor/migrations/0001_initial.py +++ /dev/null @@ -1,109 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name="AdapterInstance", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - db_comment="Unique identifier for the Adapter Instance", - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "adapter_name", - models.TextField( - db_comment="Name of the Adapter Instance", - max_length=128, - ), - ), - ( - "adapter_id", - models.CharField( - db_comment="Unique identifier of the Adapter", - default="", - max_length=128, - ), - ), - ( - "adapter_metadata", - models.JSONField( - db_column="adapter_metadata", - db_comment="JSON adapter metadata submitted by the user", - default=dict, - ), - ), - ( - "adapter_type", - models.CharField( - choices=[ - ("UNKNOWN", "UNKNOWN"), - ("LLM", "LLM"), - ("EMBEDDING", "EMBEDDING"), - ("VECTOR_DB", "VECTOR_DB"), - ], - db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB", - ), - ), - ( - "is_active", - models.BooleanField( - db_comment="Is the adapter instance currently being used", - default=False, - ), - ), - ( - "is_default", - models.BooleanField( - db_comment="Is the adapter instance default", - default=False, - ), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_adapters", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_adapters", - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "verbose_name": "adapter_adapterinstance", - "verbose_name_plural": "adapter_adapterinstance", - "db_table": "adapter_adapterinstance", - }, - ), - ] diff --git a/backend/adapter_processor/migrations/0002_adapterinstance_unique_adapter.py b/backend/adapter_processor/migrations/0002_adapterinstance_unique_adapter.py deleted file mode 100644 index 5b1c46593..000000000 --- a/backend/adapter_processor/migrations/0002_adapterinstance_unique_adapter.py +++ /dev/null @@ -1,18 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-20 08:32 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("adapter_processor", "0001_initial"), - ] - - operations = [ - migrations.AddConstraint( - model_name="adapterinstance", - constraint=models.UniqueConstraint( - fields=("adapter_name", "adapter_type"), name="unique_adapter" - ), - ), - ] diff --git a/backend/adapter_processor/migrations/0003_adapterinstance_adapter_metadata_b.py b/backend/adapter_processor/migrations/0003_adapterinstance_adapter_metadata_b.py deleted file mode 100644 index 75abf4ae1..000000000 --- a/backend/adapter_processor/migrations/0003_adapterinstance_adapter_metadata_b.py +++ /dev/null @@ -1,40 +0,0 @@ -# Generated by Django 4.2.1 on 2024-02-13 13:09 - -import json -from typing import Any - -from cryptography.fernet import Fernet -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("adapter_processor", "0002_adapterinstance_unique_adapter"), - ] - - def EncryptCredentials(apps: Any, schema_editor: Any) -> None: - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - AdapterInstance = apps.get_model("adapter_processor", "AdapterInstance") - queryset = AdapterInstance.objects.all() - - for obj in queryset: # type: ignore - # Access attributes of the object - - print(f"Object ID: {obj.id}, Name: {obj.adapter_name}") - if hasattr(obj, "adapter_metadata"): - json_string: str = json.dumps(obj.adapter_metadata) - obj.adapter_metadata_b = f.encrypt(json_string.encode("utf-8")) - obj.save() - - operations = [ - migrations.AddField( - model_name="adapterinstance", - name="adapter_metadata_b", - field=models.BinaryField(null=True), - ), - migrations.RunPython( - EncryptCredentials, reverse_code=migrations.RunPython.noop - ), - ] diff --git a/backend/adapter_processor/migrations/0004_alter_adapterinstance_adapter_type.py b/backend/adapter_processor/migrations/0004_alter_adapterinstance_adapter_type.py deleted file mode 100644 index 6321d6219..000000000 --- a/backend/adapter_processor/migrations/0004_alter_adapterinstance_adapter_type.py +++ /dev/null @@ -1,26 +0,0 @@ -# Generated by Django 4.2.1 on 2024-02-23 09:29 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("adapter_processor", "0003_adapterinstance_adapter_metadata_b"), - ] - - operations = [ - migrations.AlterField( - model_name="adapterinstance", - name="adapter_type", - field=models.CharField( - choices=[ - ("UNKNOWN", "UNKNOWN"), - ("LLM", "LLM"), - ("EMBEDDING", "EMBEDDING"), - ("VECTOR_DB", "VECTOR_DB"), - ("X2TEXT", "X2TEXT"), - ], - db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB", - ), - ), - ] diff --git a/backend/adapter_processor/migrations/0005_alter_adapterinstance_adapter_type.py b/backend/adapter_processor/migrations/0005_alter_adapterinstance_adapter_type.py deleted file mode 100644 index c0631e723..000000000 --- a/backend/adapter_processor/migrations/0005_alter_adapterinstance_adapter_type.py +++ /dev/null @@ -1,27 +0,0 @@ -# Generated by Django 4.2.1 on 2024-02-28 09:03 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("adapter_processor", "0004_alter_adapterinstance_adapter_type"), - ] - - operations = [ - migrations.AlterField( - model_name="adapterinstance", - name="adapter_type", - field=models.CharField( - choices=[ - ("UNKNOWN", "UNKNOWN"), - ("LLM", "LLM"), - ("EMBEDDING", "EMBEDDING"), - ("VECTOR_DB", "VECTOR_DB"), - ("OCR", "OCR"), - ("X2TEXT", "X2TEXT"), - ], - db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB", - ), - ), - ] diff --git a/backend/adapter_processor/migrations/0006_adapterinstance_shared_users.py b/backend/adapter_processor/migrations/0006_adapterinstance_shared_users.py deleted file mode 100644 index 8bef69a85..000000000 --- a/backend/adapter_processor/migrations/0006_adapterinstance_shared_users.py +++ /dev/null @@ -1,21 +0,0 @@ -# Generated by Django 4.2.1 on 2024-03-11 07:55 - -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("adapter_processor", "0005_alter_adapterinstance_adapter_type"), - ] - - operations = [ - migrations.AddField( - model_name="adapterinstance", - name="shared_users", - field=models.ManyToManyField( - related_name="shared_adapters", to=settings.AUTH_USER_MODEL - ), - ), - ] diff --git a/backend/adapter_processor/migrations/0007_remove_adapterinstance_is_default_userdefaultadapter.py b/backend/adapter_processor/migrations/0007_remove_adapterinstance_is_default_userdefaultadapter.py deleted file mode 100644 index a493ad46a..000000000 --- a/backend/adapter_processor/migrations/0007_remove_adapterinstance_is_default_userdefaultadapter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Generated by Django 4.2.1 on 2024-03-14 11:37 - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("adapter_processor", "0006_adapterinstance_shared_users"), - ] - - operations = [ - migrations.RemoveField( - model_name="adapterinstance", - name="is_default", - ), - migrations.CreateModel( - name="UserDefaultAdapter", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "default_embedding_adapter", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="default_embedding_adapter", - to="adapter_processor.adapterinstance", - ), - ), - ( - "default_llm_adapter", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="default_llm_adapter", - to="adapter_processor.adapterinstance", - ), - ), - ( - "default_vector_db_adapter", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="default_vector_db_adapter", - to="adapter_processor.adapterinstance", - ), - ), - ( - "default_x2text_adapter", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="default_x2text_adapter", - to="adapter_processor.adapterinstance", - ), - ), - ( - "user", - models.OneToOneField( - on_delete=django.db.models.deletion.CASCADE, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "abstract": False, - }, - ), - ] diff --git a/backend/adapter_processor/migrations/0008_adapterinstance_description_and_more.py b/backend/adapter_processor/migrations/0008_adapterinstance_description_and_more.py deleted file mode 100644 index ec2c607c6..000000000 --- a/backend/adapter_processor/migrations/0008_adapterinstance_description_and_more.py +++ /dev/null @@ -1,41 +0,0 @@ -# Generated by Django 4.2.1 on 2024-04-29 05:16 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ( - "adapter_processor", - "0007_remove_adapterinstance_is_default_userdefaultadapter", - ), - ] - - operations = [ - migrations.AddField( - model_name="adapterinstance", - name="description", - field=models.TextField(blank=True, default=None, null=True), - ), - migrations.AddField( - model_name="adapterinstance", - name="is_friction_less", - field=models.BooleanField( - db_comment="Was the adapter created through frictionless onboarding", - default=False, - ), - ), - migrations.AddField( - model_name="adapterinstance", - name="is_usable", - field=models.BooleanField(db_comment="Is the Adpater Usable", default=True), - ), - migrations.AddField( - model_name="adapterinstance", - name="shared_to_org", - field=models.BooleanField( - db_comment="Is the adapter shared to entire org", default=False - ), - ), - ] diff --git a/backend/adapter_processor/migrations/0009_frictionless_adapter_data_migration.py b/backend/adapter_processor/migrations/0009_frictionless_adapter_data_migration.py deleted file mode 100644 index e4eb37167..000000000 --- a/backend/adapter_processor/migrations/0009_frictionless_adapter_data_migration.py +++ /dev/null @@ -1,71 +0,0 @@ -# Generated by Django 4.2.1 on 2024-06-02 04:17 - -import json -import logging -from typing import Any - -from adapter_processor.constants import AdapterKeys -from cryptography.fernet import Fernet -from django.conf import settings -from django.db import migrations - -logger = logging.getLogger(__name__) - - -class Migration(migrations.Migration): - - dependencies = [ - ("adapter_processor", "0008_adapterinstance_description_and_more"), - ] - - def migrate_frictionless_adapter_data(apps: Any, schema_editor: Any) -> None: - """Migrates the data for the frictionless adapter by encrypting and - saving the adapter metadata in the database. - - Parameters: - apps (Any): The registry of installed applications. - schema_editor (Any): The schema editor for the database operation. - - Returns: - None: This method does not return anything. - - Raises: - None: This method does not raise any exceptions. - """ - VECTOR_DB_CONF = "VECTOR_DB_CONF" - - # check vector db conf is added in env - if not hasattr(settings, VECTOR_DB_CONF): - return - - AdapterInstance = apps.get_model("adapter_processor", "AdapterInstance") - - try: - adapter_instance = AdapterInstance.objects.get( - is_friction_less=True, - is_usable=True, - adapter_type=AdapterKeys.VECTOR_DB, - ) - - vector_db_conf = json.loads(getattr(settings, VECTOR_DB_CONF)) - adapter_metadata = vector_db_conf["adapter_metadata"] - - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - - adapter_metadata_b: bytes = f.encrypt( - json.dumps(adapter_metadata).encode("utf-8") - ) - - adapter_instance.adapter_metadata_b = adapter_metadata_b - adapter_instance.save() - - except AdapterInstance.DoesNotExist: - logger.info("Skip data migration as frictionless vector-db not available") - - operations = [ - migrations.RunPython( - migrate_frictionless_adapter_data, - reverse_code=migrations.RunPython.noop, - ), - ] diff --git a/backend/adapter_processor/migrations/__init__.py b/backend/adapter_processor/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/adapter_processor/models.py b/backend/adapter_processor/models.py deleted file mode 100644 index f856bcf6b..000000000 --- a/backend/adapter_processor/models.py +++ /dev/null @@ -1,189 +0,0 @@ -import json -import logging -import uuid -from typing import Any - -from account.models import User -from cryptography.fernet import Fernet, InvalidToken -from django.conf import settings -from django.db import models -from django.db.models import QuerySet -from unstract.sdk.adapters.adapterkit import Adapterkit -from unstract.sdk.adapters.enums import AdapterTypes -from unstract.sdk.adapters.exceptions import AdapterError -from utils.exceptions import InvalidEncryptionKey -from utils.models.base_model import BaseModel - -ADAPTER_NAME_SIZE = 128 -VERSION_NAME_SIZE = 64 -ADAPTER_ID_LENGTH = 128 - -logger = logging.getLogger(__name__) - - -class AdapterInstanceModelManager(models.Manager): - def get_queryset(self) -> QuerySet[Any]: - return super().get_queryset() - - def for_user(self, user: User) -> QuerySet[Any]: - return ( - self.get_queryset() - .filter( - models.Q(created_by=user) - | models.Q(shared_users=user) - | models.Q(shared_to_org=True) - | models.Q(is_friction_less=True) - ) - .distinct("id") - ) - - -class AdapterInstance(BaseModel): - id = models.UUIDField( - primary_key=True, - default=uuid.uuid4, - editable=False, - db_comment="Unique identifier for the Adapter Instance", - ) - adapter_name = models.TextField( - max_length=ADAPTER_NAME_SIZE, - null=False, - blank=False, - db_comment="Name of the Adapter Instance", - ) - adapter_id = models.CharField( - max_length=ADAPTER_ID_LENGTH, - default="", - db_comment="Unique identifier of the Adapter", - ) - - # TODO to be removed once the migration for encryption - adapter_metadata = models.JSONField( - db_column="adapter_metadata", - null=False, - blank=False, - default=dict, - db_comment="JSON adapter metadata submitted by the user", - ) - adapter_metadata_b = models.BinaryField(null=True) - adapter_type = models.CharField( - choices=[(tag.value, tag.name) for tag in AdapterTypes], - db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB", - ) - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="created_adapters", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="modified_adapters", - null=True, - blank=True, - ) - - is_active = models.BooleanField( - default=False, - db_comment="Is the adapter instance currently being used", - ) - shared_to_org = models.BooleanField( - default=False, - db_comment="Is the adapter shared to entire org", - ) - - is_friction_less = models.BooleanField( - default=False, - db_comment="Was the adapter created through frictionless onboarding", - ) - - # Can be used if the adapter usage gets exhausted - # Can also be used in other possible scenarios in feature - is_usable = models.BooleanField( - default=True, - db_comment="Is the Adpater Usable", - ) - - # Introduced field to establish M2M relation between users and adapters. - # This will introduce intermediary table which relates both the models. - shared_users = models.ManyToManyField(User, related_name="shared_adapters") - description = models.TextField(blank=True, null=True, default=None) - - objects = AdapterInstanceModelManager() - - class Meta: - verbose_name = "adapter_adapterinstance" - verbose_name_plural = "adapter_adapterinstance" - db_table = "adapter_adapterinstance" - constraints = [ - models.UniqueConstraint( - fields=["adapter_name", "adapter_type"], - name="unique_adapter", - ), - ] - - def create_adapter(self) -> None: - - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - - self.adapter_metadata_b = f.encrypt( - json.dumps(self.adapter_metadata).encode("utf-8") - ) - self.adapter_metadata = {} - - self.save() - - @property - def metadata(self) -> Any: - try: - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - - adapter_metadata = json.loads( - f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8")) - ) - except InvalidToken: - raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.ADAPTER) - return adapter_metadata - - def get_context_window_size(self) -> int: - # Get the adapter_instance - adapter_class = Adapterkit().get_adapter_class_by_adapter_id(self.adapter_id) - try: - adapter_instance = adapter_class(self.metadata) - return adapter_instance.get_context_window_size() - except AdapterError as e: - logger.warning(f"Unable to retrieve context window size - {e}") - return 0 - - -class UserDefaultAdapter(BaseModel): - user = models.OneToOneField(User, on_delete=models.CASCADE) - default_llm_adapter = models.ForeignKey( - AdapterInstance, - on_delete=models.SET_NULL, - null=True, - related_name="default_llm_adapter", - ) - default_embedding_adapter = models.ForeignKey( - AdapterInstance, - on_delete=models.SET_NULL, - null=True, - related_name="default_embedding_adapter", - ) - default_vector_db_adapter = models.ForeignKey( - AdapterInstance, - on_delete=models.SET_NULL, - null=True, - related_name="default_vector_db_adapter", - ) - - default_x2text_adapter = models.ForeignKey( - AdapterInstance, - on_delete=models.SET_NULL, - null=True, - related_name="default_x2text_adapter", - ) diff --git a/backend/adapter_processor/serializers.py b/backend/adapter_processor/serializers.py deleted file mode 100644 index 1e74c19ad..000000000 --- a/backend/adapter_processor/serializers.py +++ /dev/null @@ -1,158 +0,0 @@ -import json -from typing import Any - -from account.serializer import UserSerializer -from adapter_processor.adapter_processor import AdapterProcessor -from adapter_processor.constants import AdapterKeys -from cryptography.fernet import Fernet -from django.conf import settings -from rest_framework import serializers -from rest_framework.serializers import ModelSerializer -from unstract.sdk.adapters.constants import Common as common -from unstract.sdk.adapters.enums import AdapterTypes - -from backend.constants import FieldLengthConstants as FLC -from backend.serializers import AuditSerializer - -from .models import AdapterInstance, UserDefaultAdapter - - -class TestAdapterSerializer(serializers.Serializer): - adapter_id = serializers.CharField(max_length=FLC.ADAPTER_ID_LENGTH) - adapter_metadata = serializers.JSONField() - adapter_type = serializers.JSONField() - - -class BaseAdapterSerializer(AuditSerializer): - class Meta: - model = AdapterInstance - fields = "__all__" - - -class DefaultAdapterSerializer(serializers.Serializer): - llm_default = serializers.CharField(max_length=FLC.UUID_LENGTH, required=False) - embedding_default = serializers.CharField( - max_length=FLC.UUID_LENGTH, required=False - ) - vector_db_default = serializers.CharField( - max_length=FLC.UUID_LENGTH, required=False - ) - - -class AdapterInstanceSerializer(BaseAdapterSerializer): - """Inherits BaseAdapterSerializer. - - Used for CRUD other than listing - """ - - def to_internal_value(self, data: dict[str, Any]) -> dict[str, Any]: - if data.get(AdapterKeys.ADAPTER_METADATA, None): - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA)) - - data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt( - json_string.encode("utf-8") - ) - - return data - - def to_representation(self, instance: AdapterInstance) -> dict[str, str]: - rep: dict[str, str] = super().to_representation(instance) - - rep.pop(AdapterKeys.ADAPTER_METADATA_B) - adapter_metadata = instance.metadata - rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata - # Retrieve context window if adapter is a LLM - # For other adapter types, context_window is not relevant. - if instance.adapter_type == AdapterTypes.LLM.value: - adapter_metadata[AdapterKeys.ADAPTER_CONTEXT_WINDOW_SIZE] = ( - instance.get_context_window_size() - ) - - rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( - instance.adapter_id, common.ICON - ) - rep[AdapterKeys.ADAPTER_CREATED_BY] = instance.created_by.email - - return rep - - -class AdapterInfoSerializer(BaseAdapterSerializer): - - context_window_size = serializers.SerializerMethodField() - - class Meta(BaseAdapterSerializer.Meta): - model = AdapterInstance - fields = ( - "id", - "adapter_id", - "adapter_name", - "adapter_type", - "created_by", - "context_window_size", - ) # type: ignore - - def get_context_window_size(self, obj: AdapterInstance) -> int: - return obj.get_context_window_size() - - -class AdapterListSerializer(BaseAdapterSerializer): - """Inherits BaseAdapterSerializer. - - Used for listing adapters - """ - - class Meta(BaseAdapterSerializer.Meta): - model = AdapterInstance - fields = ( - "id", - "adapter_id", - "adapter_name", - "adapter_type", - "created_by", - "description", - ) # type: ignore - - def to_representation(self, instance: AdapterInstance) -> dict[str, str]: - rep: dict[str, str] = super().to_representation(instance) - rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( - instance.adapter_id, common.ICON - ) - model = instance.metadata.get("model") - if model: - rep["model"] = model - - if instance.is_friction_less: - rep["created_by_email"] = "Unstract" - else: - rep["created_by_email"] = instance.created_by.email - - return rep - - -class SharedUserListSerializer(BaseAdapterSerializer): - """Inherits BaseAdapterSerializer. - - Used for listing adapter users - """ - - shared_users = UserSerializer(many=True) - created_by = UserSerializer() - - class Meta(BaseAdapterSerializer.Meta): - model = AdapterInstance - fields = ( - "id", - "adapter_id", - "adapter_name", - "adapter_type", - "created_by", - "shared_users", - ) # type: ignore - - -class UserDefaultAdapterSerializer(ModelSerializer): - class Meta: - model = UserDefaultAdapter - fields = "__all__" diff --git a/backend/adapter_processor/urls.py b/backend/adapter_processor/urls.py deleted file mode 100644 index 215e4e6c9..000000000 --- a/backend/adapter_processor/urls.py +++ /dev/null @@ -1,42 +0,0 @@ -from adapter_processor.views import ( - AdapterInstanceViewSet, - AdapterViewSet, - DefaultAdapterViewSet, -) -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns - -default_triad = DefaultAdapterViewSet.as_view( - {"post": "configure_default_triad", "get": "get_default_triad"} -) -adapter = AdapterViewSet.as_view({"get": "list"}) -adapter_schema = AdapterViewSet.as_view({"get": "get_adapter_schema"}) -adapter_test = AdapterViewSet.as_view({"post": "test"}) -adapter_list = AdapterInstanceViewSet.as_view({"post": "create", "get": "list"}) -adapter_detail = AdapterInstanceViewSet.as_view( - { - "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", - } -) - -adapter_users = AdapterInstanceViewSet.as_view({"get": "list_of_shared_users"}) -adapter_info = AdapterInstanceViewSet.as_view({"get": "adapter_info"}) -urlpatterns = format_suffix_patterns( - [ - path("adapter_schema/", adapter_schema, name="get_adapter_schema"), - path("supported_adapters/", adapter, name="adapter-list"), - path("adapter/", adapter_list, name="adapter-list"), - path("adapter/default_triad/", default_triad, name="default_triad"), - path("adapter//", adapter_detail, name="adapter_detail"), - path("adapter/info//", adapter_info, name="adapter_info"), - path("test_adapters/", adapter_test, name="adapter-test"), - path( - "adapter/users//", - adapter_users, - name="adapter-users", - ), - ] -) diff --git a/backend/adapter_processor/views.py b/backend/adapter_processor/views.py deleted file mode 100644 index b262dd04d..000000000 --- a/backend/adapter_processor/views.py +++ /dev/null @@ -1,309 +0,0 @@ -import logging -import uuid -from typing import Any, Optional - -from adapter_processor.adapter_processor import AdapterProcessor -from adapter_processor.constants import AdapterKeys -from adapter_processor.exceptions import ( - CannotDeleteDefaultAdapter, - DeleteAdapterInUseError, - DuplicateAdapterNameError, - IdIsMandatory, - InValidType, -) -from adapter_processor.serializers import ( - AdapterInfoSerializer, - AdapterInstanceSerializer, - AdapterListSerializer, - DefaultAdapterSerializer, - SharedUserListSerializer, - TestAdapterSerializer, - UserDefaultAdapterSerializer, -) -from django.db import IntegrityError -from django.db.models import ProtectedError, QuerySet -from django.http import HttpRequest -from django.http.response import HttpResponse -from permissions.permission import ( - IsFrictionLessAdapter, - IsFrictionLessAdapterDelete, - IsOwner, - IsOwnerOrSharedUserOrSharedToOrg, -) -from rest_framework import status -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.serializers import ModelSerializer -from rest_framework.versioning import URLPathVersioning -from rest_framework.viewsets import GenericViewSet, ModelViewSet -from utils.filtering import FilterHelper - -from .constants import AdapterKeys as constant -from .models import AdapterInstance, UserDefaultAdapter - -logger = logging.getLogger(__name__) - - -class DefaultAdapterViewSet(ModelViewSet): - versioning_class = URLPathVersioning - serializer_class = DefaultAdapterSerializer - - def configure_default_triad( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> HttpResponse: - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - # Convert request data to json - default_triad = request.data - AdapterProcessor.set_default_triad(default_triad, request.user) - return Response(status=status.HTTP_200_OK) - - def get_default_triad( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> HttpResponse: - try: - user_default_adapter = UserDefaultAdapter.objects.get(user=request.user) - serializer = UserDefaultAdapterSerializer(user_default_adapter).data - return Response(serializer) - - except UserDefaultAdapter.DoesNotExist: - # Handle the case when no records are found - return Response(status=status.HTTP_200_OK, data={}) - - -class AdapterViewSet(GenericViewSet): - versioning_class = URLPathVersioning - serializer_class = TestAdapterSerializer - - def list( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> HttpResponse: - if request.method == "GET": - adapter_type = request.GET.get(AdapterKeys.ADAPTER_TYPE) - if ( - adapter_type == AdapterKeys.LLM - or adapter_type == AdapterKeys.EMBEDDING - or adapter_type == AdapterKeys.VECTOR_DB - or adapter_type == AdapterKeys.X2TEXT - or adapter_type == AdapterKeys.OCR - ): - json_schema = AdapterProcessor.get_all_supported_adapters( - type=adapter_type - ) - return Response(json_schema, status=status.HTTP_200_OK) - else: - raise InValidType - - def get_adapter_schema( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> HttpResponse: - if request.method == "GET": - adapter_name = request.GET.get(AdapterKeys.ID) - if adapter_name is None or adapter_name == "": - raise IdIsMandatory() - json_schema = AdapterProcessor.get_json_schema(adapter_id=adapter_name) - return Response(data=json_schema, status=status.HTTP_200_OK) - - def test(self, request: Request) -> Response: - """Tests the connector against the credentials passed.""" - serializer: AdapterInstanceSerializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - adapter_id = serializer.validated_data.get(AdapterKeys.ADAPTER_ID) - adapter_metadata = serializer.validated_data.get(AdapterKeys.ADAPTER_METADATA) - adapter_metadata[AdapterKeys.ADAPTER_TYPE] = serializer.validated_data.get( - AdapterKeys.ADAPTER_TYPE - ) - test_result = AdapterProcessor.test_adapter( - adapter_id=adapter_id, adapter_metadata=adapter_metadata - ) - return Response( - {AdapterKeys.IS_VALID: test_result}, - status=status.HTTP_200_OK, - ) - - -class AdapterInstanceViewSet(ModelViewSet): - - serializer_class = AdapterInstanceSerializer - - def get_permissions(self) -> list[Any]: - - if self.action in ["update", "retrieve"]: - return [IsFrictionLessAdapter()] - - elif self.action == "destroy": - return [IsFrictionLessAdapterDelete()] - - elif self.action in ["list_of_shared_users", "adapter_info"]: - return [IsOwnerOrSharedUserOrSharedToOrg()] - - # Hack for friction-less onboarding - # User cant view/update metadata but can delete/share etc - return [IsOwner()] - - def get_queryset(self) -> Optional[QuerySet]: - if filter_args := FilterHelper.build_filter_args( - self.request, - constant.ADAPTER_TYPE, - ): - queryset = AdapterInstance.objects.for_user(self.request.user).filter( - **filter_args - ) - else: - queryset = AdapterInstance.objects.for_user(self.request.user) - return queryset - - def get_serializer_class( - self, - ) -> ModelSerializer: - if self.action == "list": - return AdapterListSerializer - return AdapterInstanceSerializer - - def create(self, request: Any) -> Response: - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - try: - instance = serializer.save() - - # Check to see if there is a default configured - # for this adapter_type and for the current user - ( - user_default_adapter, - created, - ) = UserDefaultAdapter.objects.get_or_create(user=request.user) - - adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) - if (adapter_type == AdapterKeys.LLM) and ( - not user_default_adapter.default_llm_adapter - ): - user_default_adapter.default_llm_adapter = instance - - elif (adapter_type == AdapterKeys.EMBEDDING) and ( - not user_default_adapter.default_embedding_adapter - ): - user_default_adapter.default_embedding_adapter = instance - elif (adapter_type == AdapterKeys.VECTOR_DB) and ( - not user_default_adapter.default_vector_db_adapter - ): - user_default_adapter.default_vector_db_adapter = instance - elif (adapter_type == AdapterKeys.X2TEXT) and ( - not user_default_adapter.default_x2text_adapter - ): - user_default_adapter.default_x2text_adapter = instance - - user_default_adapter.save() - - except IntegrityError: - raise DuplicateAdapterNameError( - name=serializer.validated_data.get(AdapterKeys.ADAPTER_NAME) - ) - headers = self.get_success_headers(serializer.data) - return Response( - serializer.data, status=status.HTTP_201_CREATED, headers=headers - ) - - def destroy( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - adapter_instance: AdapterInstance = self.get_object() - adapter_type = adapter_instance.adapter_type - try: - user_default_adapter: UserDefaultAdapter = UserDefaultAdapter.objects.get( - user=request.user - ) - - if ( - ( - adapter_type == AdapterKeys.LLM - and adapter_instance == user_default_adapter.default_llm_adapter - ) - or ( - adapter_type == AdapterKeys.EMBEDDING - and adapter_instance - == user_default_adapter.default_embedding_adapter - ) - or ( - adapter_type == AdapterKeys.VECTOR_DB - and adapter_instance - == user_default_adapter.default_vector_db_adapter - ) - or ( - adapter_type == AdapterKeys.X2TEXT - and adapter_instance == user_default_adapter.default_x2text_adapter - ) - ): - logger.error("Cannot delete a default adapter") - raise CannotDeleteDefaultAdapter() - except UserDefaultAdapter.DoesNotExist: - # We can go head and remove adapter here - logger.info("User default adpater doesnt not exist") - - try: - super().perform_destroy(adapter_instance) - except ProtectedError: - logger.error( - f"Failed to delete adapter: {adapter_instance.adapter_id}" - f" named {adapter_instance.adapter_name}" - ) - # TODO: Provide details of adpter usage with exception object - raise DeleteAdapterInUseError(adapter_name=adapter_instance.adapter_name) - return Response(status=status.HTTP_204_NO_CONTENT) - - def partial_update( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - if AdapterKeys.SHARED_USERS in request.data: - # find the deleted users - adapter = self.get_object() - shared_users = { - int(user_id) for user_id in request.data.get("shared_users", {}) - } - current_users = {user.id for user in adapter.shared_users.all()} - removed_users = current_users.difference(shared_users) - - # if removed user use this adapter as default - # Remove the same from his default - for user_id in removed_users: - try: - user_default_adapter = UserDefaultAdapter.objects.get( - user_id=user_id - ) - - if user_default_adapter.default_llm_adapter == adapter: - user_default_adapter.default_llm_adapter = None - elif user_default_adapter.default_embedding_adapter == adapter: - user_default_adapter.default_embedding_adapter = None - elif user_default_adapter.default_vector_db_adapter == adapter: - user_default_adapter.default_vector_db_adapter = None - elif user_default_adapter.default_x2text_adapter == adapter: - user_default_adapter.default_x2text_adapter = None - - user_default_adapter.save() - except UserDefaultAdapter.DoesNotExist: - logger.debug( - "User id : %s doesnt have default adapters configured", - user_id, - ) - continue - - return super().partial_update(request, *args, **kwargs) - - @action(detail=True, methods=["get"]) - def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response: - - adapter = self.get_object() - - serialized_instances = SharedUserListSerializer(adapter).data - - return Response(serialized_instances) - - @action(detail=True, methods=["get"]) - def adapter_info(self, request: HttpRequest, pk: uuid) -> Response: - - adapter = self.get_object() - - serialized_instances = AdapterInfoSerializer(adapter).data - - return Response(serialized_instances) diff --git a/backend/api/__init__.py b/backend/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/api/admin.py b/backend/api/admin.py deleted file mode 100644 index 37f0837a7..000000000 --- a/backend/api/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import APIDeployment, APIKey - -admin.site.register([APIDeployment, APIKey]) diff --git a/backend/api/api_deployment_views.py b/backend/api/api_deployment_views.py deleted file mode 100644 index 99749988e..000000000 --- a/backend/api/api_deployment_views.py +++ /dev/null @@ -1,156 +0,0 @@ -import json -import logging -from typing import Any, Optional - -from api.constants import ApiExecution -from api.deployment_helper import DeploymentHelper -from api.exceptions import InvalidAPIRequest, NoActiveAPIKeyError -from api.models import APIDeployment -from api.postman_collection.dto import PostmanCollection -from api.serializers import ( - APIDeploymentListSerializer, - APIDeploymentSerializer, - DeploymentResponseSerializer, - ExecutionRequestSerializer, -) -from django.db.models import QuerySet -from django.http import HttpResponse -from permissions.permission import IsOwner -from rest_framework import serializers, status, views, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.serializers import Serializer -from utils.enums import CeleryTaskState -from workflow_manager.workflow.dto import ExecutionResponse - -logger = logging.getLogger(__name__) - - -class DeploymentExecution(views.APIView): - def initialize_request( - self, request: Request, *args: Any, **kwargs: Any - ) -> Request: - """To remove csrf request for public API. - - Args: - request (Request): _description_ - - Returns: - Request: _description_ - """ - setattr(request, "csrf_processing_done", True) - return super().initialize_request(request, *args, **kwargs) - - @DeploymentHelper.validate_api_key - def post( - self, request: Request, org_name: str, api_name: str, api: APIDeployment - ) -> Response: - file_objs = request.FILES.getlist(ApiExecution.FILES_FORM_DATA) - serializer = ExecutionRequestSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - timeout = serializer.validated_data.get(ApiExecution.TIMEOUT_FORM_DATA) - include_metadata = serializer.validated_data.get(ApiExecution.INCLUDE_METADATA) - use_file_history = serializer.validated_data.get(ApiExecution.USE_FILE_HISTORY) - if not file_objs or len(file_objs) == 0: - raise InvalidAPIRequest("File shouldn't be empty") - response = DeploymentHelper.execute_workflow( - organization_name=org_name, - api=api, - file_objs=file_objs, - timeout=timeout, - include_metadata=include_metadata, - use_file_history=use_file_history, - ) - if "error" in response and response["error"]: - return Response( - {"message": response}, - status=status.HTTP_422_UNPROCESSABLE_ENTITY, - ) - return Response({"message": response}, status=status.HTTP_200_OK) - - @DeploymentHelper.validate_api_key - def get( - self, request: Request, org_name: str, api_name: str, api: APIDeployment - ) -> Response: - execution_id = request.query_params.get("execution_id") - include_metadata = ( - request.query_params.get(ApiExecution.INCLUDE_METADATA, "false").lower() - == "true" - ) - if not execution_id: - raise InvalidAPIRequest("execution_id shouldn't be empty") - response: ExecutionResponse = DeploymentHelper.get_execution_status( - execution_id=execution_id - ) - response_status = status.HTTP_422_UNPROCESSABLE_ENTITY - if response.execution_status == CeleryTaskState.COMPLETED.value: - response_status = status.HTTP_200_OK - if not include_metadata: - response.remove_result_metadata_keys() - return Response( - data={ - "status": response.execution_status, - "message": response.result, - }, - status=response_status, - ) - - -class APIDeploymentViewSet(viewsets.ModelViewSet): - permission_classes = [IsOwner] - - def get_queryset(self) -> Optional[QuerySet]: - return APIDeployment.objects.filter(created_by=self.request.user) - - def get_serializer_class(self) -> serializers.Serializer: - if self.action in ["list"]: - return APIDeploymentListSerializer - return APIDeploymentSerializer - - @action(detail=True, methods=["get"]) - def fetch_one(self, request: Request, pk: Optional[str] = None) -> Response: - """Custom action to fetch a single instance.""" - instance = self.get_object() - serializer = self.get_serializer(instance) - return Response(serializer.data) - - def create( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - serializer: Serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - self.perform_create(serializer) - api_key = DeploymentHelper.create_api_key(serializer=serializer) - response_serializer = DeploymentResponseSerializer( - {"api_key": api_key.api_key, **serializer.data} - ) - - headers = self.get_success_headers(serializer.data) - return Response( - response_serializer.data, - status=status.HTTP_201_CREATED, - headers=headers, - ) - - @action(detail=True, methods=["get"]) - def download_postman_collection( - self, request: Request, pk: Optional[str] = None - ) -> Response: - """Downloads a Postman Collection of the API deployment instance.""" - instance = self.get_object() - api_key_inst = instance.apikey_set.filter(is_active=True).first() - if not api_key_inst: - logger.error(f"No active API key set for deployment {instance.pk}") - raise NoActiveAPIKeyError(deployment_name=instance.display_name) - - postman_collection = PostmanCollection.create( - instance=instance, api_key=api_key_inst.api_key - ) - response = HttpResponse( - json.dumps(postman_collection.to_dict()), content_type="application/json" - ) - response["Content-Disposition"] = ( - f'attachment; filename="{instance.display_name}.json"' - ) - return response diff --git a/backend/api/api_key_validator.py b/backend/api/api_key_validator.py deleted file mode 100644 index 5883b2e11..000000000 --- a/backend/api/api_key_validator.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from functools import wraps -from typing import Any - -from api.exceptions import Forbidden -from django_tenants.utils import get_tenant_model, tenant_context -from rest_framework.request import Request - -logger = logging.getLogger(__name__) - - -class BaseAPIKeyValidator: - @classmethod - def validate_api_key(cls, func: Any) -> Any: - """Decorator that validates the API key. - - Sample header: - Authorization: Bearer 123e4567-e89b-12d3-a456-426614174001 - Args: - func (Any): Function to wrap for validation - """ - - @wraps(func) - def wrapper(self: Any, request: Request, *args: Any, **kwargs: Any) -> Any: - """Wrapper to validate the inputs and key. - - Args: - request (Request): Request context - - Raises: - Forbidden: _description_ - APINotFound: _description_ - - Returns: - Any: _description_ - """ - authorization_header = request.headers.get("Authorization") - api_key = None - if authorization_header and authorization_header.startswith("Bearer "): - api_key = authorization_header.split(" ")[1] - if not api_key: - raise Forbidden("Missing api key") - org_name = kwargs.get("org_name") or request.data.get("org_name") - cls.validate_parameters(request, **kwargs) - tenant = get_tenant_model().objects.get(schema_name=org_name) - with tenant_context(tenant): - # Call the method to handle the specific validation and processing - return cls.validate_and_process( - self, request, func, *args, **kwargs, api_key=api_key - ) - - return wrapper - - @staticmethod - def validate_parameters(request: Request, **kwargs: Any) -> None: - """Validate specific parameters required by subclasses.""" - pass - - @staticmethod - def validate_and_process( - self: Any, request: Request, func: Any, api_key: str, *args: Any, **kwargs: Any - ) -> Any: - """Process and validate API key with specific logic required by - subclasses.""" - pass diff --git a/backend/api/api_key_views.py b/backend/api/api_key_views.py deleted file mode 100644 index c7c81e9c3..000000000 --- a/backend/api/api_key_views.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional - -from api.deployment_helper import DeploymentHelper -from api.exceptions import APINotFound, PathVariablesNotFound -from api.key_helper import KeyHelper -from api.models import APIKey -from api.serializers import APIKeyListSerializer, APIKeySerializer -from permissions.permission import IsOwner -from pipeline.exceptions import PipelineNotFound -from pipeline.pipeline_processor import PipelineProcessor -from rest_framework import serializers, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response - - -class APIKeyViewSet(viewsets.ModelViewSet): - queryset = APIKey.objects.all() - permission_classes = [IsOwner] - - def get_serializer_class(self) -> serializers.Serializer: - if self.action in ["api_keys"]: - return APIKeyListSerializer - return APIKeySerializer - - @action(detail=True, methods=["get"]) - def api_keys( - self, - request: Request, - api_id: Optional[str] = None, - pipeline_id: Optional[str] = None, - ) -> Response: - """Custom action to fetch api keys of an api deployment.""" - if api_id: - api = DeploymentHelper.get_api_by_id(api_id=api_id) - if not api: - raise APINotFound() - self.check_object_permissions(request, api) - keys = KeyHelper.list_api_keys_of_api(api_instance=api) - elif pipeline_id: - pipeline = PipelineProcessor.get_active_pipeline(pipeline_id=pipeline_id) - if not pipeline: - raise PipelineNotFound() - self.check_object_permissions(request, pipeline) - keys = KeyHelper.list_api_keys_of_pipeline(pipeline_instance=pipeline) - else: - raise PathVariablesNotFound( - "Either `api_id` or `pipeline_id` path variable must be provided." - ) - serializer = self.get_serializer(keys, many=True) - return Response(serializer.data) diff --git a/backend/api/apps.py b/backend/api/apps.py deleted file mode 100644 index 14b89a829..000000000 --- a/backend/api/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class ApiConfig(AppConfig): - name = "api" diff --git a/backend/api/constants.py b/backend/api/constants.py deleted file mode 100644 index 2de41efb2..000000000 --- a/backend/api/constants.py +++ /dev/null @@ -1,7 +0,0 @@ -class ApiExecution: - PATH: str = "deployment/api" - MAXIMUM_TIMEOUT_IN_SEC: int = 300 # 5 minutes - FILES_FORM_DATA: str = "files" - TIMEOUT_FORM_DATA: str = "timeout" - INCLUDE_METADATA: str = "include_metadata" - USE_FILE_HISTORY: str = "use_file_history" # Undocumented parameter diff --git a/backend/api/deployment_helper.py b/backend/api/deployment_helper.py deleted file mode 100644 index c301f56d1..000000000 --- a/backend/api/deployment_helper.py +++ /dev/null @@ -1,217 +0,0 @@ -import logging -import uuid -from typing import Any, Optional -from urllib.parse import urlencode - -from api.api_key_validator import BaseAPIKeyValidator -from api.constants import ApiExecution -from api.exceptions import ( - ApiKeyCreateException, - APINotFound, - InactiveAPI, - InvalidAPIRequest, -) -from api.key_helper import KeyHelper -from api.models import APIDeployment, APIKey -from api.serializers import APIExecutionResponseSerializer -from api.utils import APIDeploymentUtils -from django.core.files.uploadedfile import UploadedFile -from django.db import connection -from rest_framework.request import Request -from rest_framework.serializers import Serializer -from rest_framework.utils.serializer_helpers import ReturnDict -from utils.constants import CeleryQueue -from workflow_manager.endpoint.destination import DestinationConnector -from workflow_manager.endpoint.source import SourceConnector -from workflow_manager.workflow.dto import ExecutionResponse -from workflow_manager.workflow.enums import ExecutionStatus -from workflow_manager.workflow.models.workflow import Workflow -from workflow_manager.workflow.workflow_helper import WorkflowHelper - -logger = logging.getLogger(__name__) - - -class DeploymentHelper(BaseAPIKeyValidator): - @staticmethod - def validate_parameters(request: Request, **kwargs: Any) -> None: - """Validate api_name for API deployments.""" - api_name = kwargs.get("api_name") or request.data.get("api_name") - if not api_name: - raise InvalidAPIRequest("Missing params api_name") - - @staticmethod - def validate_and_process( - self: Any, request: Request, func: Any, api_key: str, *args: Any, **kwargs: Any - ) -> Any: - """Fetch API deployment and validate API key.""" - api_name = kwargs.get("api_name") or request.data.get("api_name") - api_deployment = DeploymentHelper.get_deployment_by_api_name(api_name=api_name) - DeploymentHelper.validate_api(api_deployment=api_deployment, api_key=api_key) - kwargs["api"] = api_deployment - return func(self, request, *args, **kwargs) - - @staticmethod - def validate_api(api_deployment: Optional[APIDeployment], api_key: str) -> None: - """Validating API and API key. - - Args: - api_deployment (Optional[APIDeployment]): _description_ - api_key (str): _description_ - - Raises: - APINotFound: _description_ - InactiveAPI: _description_ - """ - if not api_deployment: - raise APINotFound() - if not api_deployment.is_active: - raise InactiveAPI() - KeyHelper.validate_api_key(api_key=api_key, instance=api_deployment) - - @staticmethod - def validate_and_get_workflow(workflow_id: str) -> Workflow: - """Validate that the specified workflow_id exists in the Workflow - model.""" - return WorkflowHelper.get_workflow_by_id(workflow_id) - - @staticmethod - def get_api_by_id(api_id: str) -> Optional[APIDeployment]: - return APIDeploymentUtils.get_api_by_id(api_id=api_id) - - @staticmethod - def construct_complete_endpoint(api_name: str) -> str: - """Constructs the complete API endpoint by appending organization - schema, endpoint path, and Django app backend URL. - - Parameters: - - endpoint (str): The endpoint path to be appended to the complete URL. - - Returns: - - str: The complete API endpoint URL. - """ - org_schema = connection.tenant.schema_name - return f"{ApiExecution.PATH}/{org_schema}/{api_name}/" - - @staticmethod - def construct_status_endpoint(api_endpoint: str, execution_id: str) -> str: - """Construct a complete status endpoint URL by appending the - execution_id as a query parameter. - - Args: - api_endpoint (str): The base API endpoint. - execution_id (str): The execution ID to be included as - a query parameter. - - Returns: - str: The complete status endpoint URL. - """ - query_parameters = urlencode({"execution_id": execution_id}) - complete_endpoint = f"/{api_endpoint}?{query_parameters}" - return complete_endpoint - - @staticmethod - def get_deployment_by_api_name( - api_name: str, - ) -> Optional[APIDeployment]: - """Get and return the APIDeployment object by api_name.""" - try: - api: APIDeployment = APIDeployment.objects.get(api_name=api_name) - return api - except APIDeployment.DoesNotExist: - return None - - @staticmethod - def create_api_key(serializer: Serializer) -> APIKey: - """To make API key for an API. - - Args: - serializer (Serializer): Request serializer - - Raises: - ApiKeyCreateException: Exception - """ - api_deployment: APIDeployment = serializer.instance - try: - api_key: APIKey = KeyHelper.create_api_key(api_deployment) - return api_key - except Exception as error: - logger.error(f"Error while creating API key error: {str(error)}") - api_deployment.delete() - logger.info("Deleted the deployment instance") - raise ApiKeyCreateException() - - @classmethod - def execute_workflow( - cls, - organization_name: str, - api: APIDeployment, - file_objs: list[UploadedFile], - timeout: int, - include_metadata: bool = False, - use_file_history: bool = False, - ) -> ReturnDict: - """Execute workflow by api. - - Args: - organization_name (str): organization name - api (APIDeployment): api model object - file_obj (UploadedFile): input file - use_file_history (bool): Use FileHistory table to return results on already - processed files. Defaults to False - - Returns: - ReturnDict: execution status/ result - """ - workflow_id = api.workflow.id - pipeline_id = api.id - execution_id = str(uuid.uuid4()) - - hash_values_of_files = SourceConnector.add_input_file_to_api_storage( - workflow_id=workflow_id, - execution_id=execution_id, - file_objs=file_objs, - use_file_history=use_file_history, - ) - try: - result = WorkflowHelper.execute_workflow_async( - workflow_id=workflow_id, - pipeline_id=pipeline_id, - hash_values_of_files=hash_values_of_files, - timeout=timeout, - execution_id=execution_id, - queue=CeleryQueue.CELERY_API_DEPLOYMENTS, - use_file_history=use_file_history, - ) - result.status_api = DeploymentHelper.construct_status_endpoint( - api_endpoint=api.api_endpoint, execution_id=execution_id - ) - if include_metadata: - result.remove_result_metadata_keys(keys_to_remove=["highlight_data"]) - else: - result.remove_result_metadata_keys() - except Exception as error: - DestinationConnector.delete_api_storage_dir( - workflow_id=workflow_id, execution_id=execution_id - ) - result = ExecutionResponse( - workflow_id=workflow_id, - execution_id=execution_id, - execution_status=ExecutionStatus.ERROR.value, - error=str(error), - ) - return APIExecutionResponseSerializer(result).data - - @staticmethod - def get_execution_status(execution_id: str) -> ExecutionResponse: - """Current status of api execution. - - Args: - execution_id (str): execution id - - Returns: - ReturnDict: status/result of execution - """ - execution_response: ExecutionResponse = WorkflowHelper.get_status_of_async_task( - execution_id=execution_id - ) - return execution_response diff --git a/backend/api/exceptions.py b/backend/api/exceptions.py deleted file mode 100644 index f2e4c415e..000000000 --- a/backend/api/exceptions.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class NotFoundException(APIException): - status_code = 404 - default_detail = "The requested resource was not found." - - -class PathVariablesNotFound(NotFoundException): - default_detail = "Path variable must be provided." - - -class MandatoryWorkflowId(APIException): - status_code = 400 - default_detail = "Workflow ID is mandatory" - - -class ApiKeyCreateException(APIException): - status_code = 500 - default_detail = "Exception while create API key" - - -class Forbidden(APIException): - status_code = 403 - default_detail = ( - "User is forbidden from performing this action. Please contact admin" - ) - - -class APINotFound(NotFoundException): - default_detail = "API not found" - - -class InvalidAPIRequest(APIException): - status_code = 400 - default_detail = "Bad request" - - -class InactiveAPI(NotFoundException): - default_detail = "API not found or Inactive" - - -class UnauthorizedKey(APIException): - status_code = 401 - default_detail = "Unauthorized" - - -class NoActiveAPIKeyError(APIException): - status_code = 409 - default_detail = "No active API keys configured for this deployment" - - def __init__( - self, - detail: Optional[str] = None, - code: Optional[str] = None, - deployment_name: str = "this deployment", - ): - if detail is None: - detail = f"No active API keys configured for {deployment_name}" - super().__init__(detail, code) diff --git a/backend/api/key_helper.py b/backend/api/key_helper.py deleted file mode 100644 index 3db67d9c8..000000000 --- a/backend/api/key_helper.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -from typing import Union - -from api.exceptions import UnauthorizedKey -from api.models import APIDeployment, APIKey -from api.serializers import APIKeySerializer -from pipeline.models import Pipeline -from workflow_manager.workflow.workflow_helper import WorkflowHelper - -logger = logging.getLogger(__name__) - - -class KeyHelper: - @staticmethod - def validate_api_key( - api_key: str, instance: Union[APIDeployment, Pipeline] - ) -> None: - """Validate api key. - - Args: - api_key (str): api key from request - instance (Union[APIDeployment, Pipeline]): api or pipeline instance - - Raises: - UnauthorizedKey: if not valid - """ - try: - api_key_instance: APIKey = APIKey.objects.get(api_key=api_key) - if not KeyHelper.has_access(api_key_instance, instance): - raise UnauthorizedKey() - except APIKey.DoesNotExist: - raise UnauthorizedKey() - - @staticmethod - def list_api_keys_of_api(api_instance: APIDeployment) -> list[APIKey]: - api_keys: list[APIKey] = APIKey.objects.filter(api=api_instance).all() - return api_keys - - @staticmethod - def list_api_keys_of_pipeline(pipeline_instance: Pipeline) -> list[APIKey]: - api_keys: list[APIKey] = APIKey.objects.filter(pipeline=pipeline_instance).all() - return api_keys - - @staticmethod - def has_access(api_key: APIKey, instance: Union[APIDeployment, Pipeline]) -> bool: - """Check if the provided API key has access to the specified API - instance. - - Args: - api_key (APIKey): api key associated with the instance - instance (Union[APIDeployment, Pipeline]): api or pipeline instance - - Returns: - bool: True if allowed to execute, False otherwise - """ - if not api_key.is_active: - return False - if isinstance(instance, APIDeployment): - return api_key.api == instance - if isinstance(instance, Pipeline): - return api_key.pipeline == instance - return False - - @staticmethod - def validate_workflow_exists(workflow_id: str) -> None: - """Validate that the specified workflow_id exists in the Workflow - model.""" - WorkflowHelper.get_workflow_by_id(workflow_id) - - @staticmethod - def create_api_key(deployment: Union[APIDeployment, Pipeline]) -> APIKey: - """Create an APIKey entity using the data from the provided - APIDeployment or Pipeline instance.""" - api_key_serializer = APIKeySerializer( - data=deployment.api_key_data, - context={"deployment": deployment}, - ) - api_key_serializer.is_valid(raise_exception=True) - api_key: APIKey = api_key_serializer.save() - return api_key diff --git a/backend/api/migrations/0001_initial.py b/backend/api/migrations/0001_initial.py deleted file mode 100644 index 8ee7d8211..000000000 --- a/backend/api/migrations/0001_initial.py +++ /dev/null @@ -1,185 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("workflow", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="APIDeployment", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "display_name", - models.CharField( - db_comment="User-given display name for the API.", - default="default api", - max_length=30, - unique=True, - ), - ), - ( - "description", - models.CharField( - blank=True, - db_comment="User-given description for the API.", - default="", - max_length=255, - ), - ), - ( - "is_active", - models.BooleanField( - db_comment="Flag indicating whether the API is active or not.", - default=True, - ), - ), - ( - "api_endpoint", - models.CharField( - db_comment="URL endpoint for the API deployment.", - editable=False, - max_length=255, - unique=True, - ), - ), - ( - "api_name", - models.CharField( - db_comment="Short name for the API deployment.", - default="default", - max_length=30, - unique=True, - ), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - editable=False, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="api_created_by", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - editable=False, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="api_modified_by", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "workflow", - models.ForeignKey( - db_comment="Foreign key reference to the Workflow model.", - on_delete=django.db.models.deletion.CASCADE, - to="workflow.workflow", - ), - ), - ], - options={ - "abstract": False, - }, - ), - migrations.CreateModel( - name="APIKey", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - db_comment="Unique identifier for the API key.", - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "api_key", - models.UUIDField( - db_comment="Actual key UUID.", - default=uuid.uuid4, - editable=False, - unique=True, - ), - ), - ( - "description", - models.CharField( - db_comment="Description of the API key.", - max_length=255, - null=True, - ), - ), - ( - "is_active", - models.BooleanField( - db_comment="Flag indicating whether the API key is active or not.", - default=True, - ), - ), - ( - "api", - models.ForeignKey( - db_comment="Foreign key reference to the APIDeployment model.", - on_delete=django.db.models.deletion.CASCADE, - to="api.apideployment", - ), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - editable=False, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="api_key_created_by", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - editable=False, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="api_key_modified_by", - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "abstract": False, - }, - ), - ] diff --git a/backend/api/migrations/0002_apikey_pipeline_alter_apikey_api.py b/backend/api/migrations/0002_apikey_pipeline_alter_apikey_api.py deleted file mode 100644 index 7f5dc61e2..000000000 --- a/backend/api/migrations/0002_apikey_pipeline_alter_apikey_api.py +++ /dev/null @@ -1,37 +0,0 @@ -# Generated by Django 4.2.1 on 2024-08-05 09:36 - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("pipeline", "0002_alter_pipeline_last_run_status"), - ("api", "0001_initial"), - ] - - operations = [ - migrations.AddField( - model_name="apikey", - name="pipeline", - field=models.ForeignKey( - blank=True, - db_comment="Foreign key reference to the Pipeline model.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - to="pipeline.pipeline", - ), - ), - migrations.AlterField( - model_name="apikey", - name="api", - field=models.ForeignKey( - blank=True, - db_comment="Foreign key reference to the APIDeployment model.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - to="api.apideployment", - ), - ), - ] diff --git a/backend/api/migrations/__init__.py b/backend/api/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/api/models.py b/backend/api/models.py deleted file mode 100644 index a6f417fef..000000000 --- a/backend/api/models.py +++ /dev/null @@ -1,155 +0,0 @@ -import uuid -from typing import Any - -from account.models import User -from api.constants import ApiExecution -from django.db import connection, models -from pipeline.models import Pipeline -from utils.models.base_model import BaseModel -from workflow_manager.workflow.models.workflow import Workflow - -API_NAME_MAX_LENGTH = 30 -DESCRIPTION_MAX_LENGTH = 255 -API_ENDPOINT_MAX_LENGTH = 255 - - -class APIDeployment(BaseModel): - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - display_name = models.CharField( - max_length=API_NAME_MAX_LENGTH, - unique=True, - default="default api", - db_comment="User-given display name for the API.", - ) - description = models.CharField( - max_length=DESCRIPTION_MAX_LENGTH, - blank=True, - default="", - db_comment="User-given description for the API.", - ) - workflow = models.ForeignKey( - Workflow, - on_delete=models.CASCADE, - db_comment="Foreign key reference to the Workflow model.", - ) - is_active = models.BooleanField( - default=True, - db_comment="Flag indicating whether the API is active or not.", - ) - # TODO: Implement dynamic generation of API endpoints for API deployments - # instead of persisting them in the database. - api_endpoint = models.CharField( - max_length=API_ENDPOINT_MAX_LENGTH, - unique=True, - editable=False, - db_comment="URL endpoint for the API deployment.", - ) - api_name = models.CharField( - max_length=API_NAME_MAX_LENGTH, - unique=True, - default="default", - db_comment="Short name for the API deployment.", - ) - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="api_created_by", - null=True, - blank=True, - editable=False, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="api_modified_by", - null=True, - blank=True, - editable=False, - ) - - @property - def api_key_data(self): - return {"api": self.id, "description": f"API Key for {self.api_name}"} - - def __str__(self) -> str: - return f"{self.id} - {self.display_name}" - - def save(self, *args: Any, **kwargs: Any) -> None: - """Save hook to update api_endpoint. - - Custom save hook for updating the 'api_endpoint' based on - 'api_name'. If the instance is being updated, it checks for - changes in 'api_name' and adjusts 'api_endpoint' - accordingly. If the instance is new, 'api_endpoint' is set - based on 'api_name' and the current database schema. - """ - if self.pk is not None: - try: - original = APIDeployment.objects.get(pk=self.pk) - if original.api_name != self.api_name: - org_schema = connection.tenant.schema_name - self.api_endpoint = ( - f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/" - ) - except APIDeployment.DoesNotExist: - org_schema = connection.tenant.schema_name - - self.api_endpoint = f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/" - super().save(*args, **kwargs) - - -class APIKey(BaseModel): - id = models.UUIDField( - primary_key=True, - editable=False, - default=uuid.uuid4, - db_comment="Unique identifier for the API key.", - ) - api_key = models.UUIDField( - default=uuid.uuid4, - editable=False, - unique=True, - db_comment="Actual key UUID.", - ) - api = models.ForeignKey( - APIDeployment, - on_delete=models.CASCADE, - null=True, - blank=True, - db_comment="Foreign key reference to the APIDeployment model.", - ) - pipeline = models.ForeignKey( - Pipeline, - on_delete=models.CASCADE, - null=True, - blank=True, - db_comment="Foreign key reference to the Pipeline model.", - ) - description = models.CharField( - max_length=DESCRIPTION_MAX_LENGTH, - null=True, - db_comment="Description of the API key.", - ) - is_active = models.BooleanField( - default=True, - db_comment="Flag indicating whether the API key is active or not.", - ) - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="api_key_created_by", - null=True, - blank=True, - editable=False, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="api_key_modified_by", - null=True, - blank=True, - editable=False, - ) - - def __str__(self) -> str: - return f"{self.api.api_name} - {self.id} - {self.api_key}" diff --git a/backend/api/notification.py b/backend/api/notification.py deleted file mode 100644 index ebf3cc390..000000000 --- a/backend/api/notification.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging - -from api.models import APIDeployment -from notification.helper import NotificationHelper -from notification.models import Notification -from pipeline.dto import PipelineStatusPayload -from workflow_manager.workflow.models.execution import WorkflowExecution - -logger = logging.getLogger(__name__) - - -class APINotification: - def __init__( - self, api: APIDeployment, workflow_execution: WorkflowExecution - ) -> None: - self.notifications = Notification.objects.filter(api=api, is_active=True) - self.api = api - self.workflow_execution = workflow_execution - - def send(self): - if not self.notifications.count(): - logger.info(f"No notifications found for api {self.api}") - return - logger.info(f"Sending api status notification for api {self.api}") - - payload_dto = PipelineStatusPayload( - type="API", - pipeline_id=self.api.id, - pipeline_name=self.api.api_name, - status=self.workflow_execution.status, - execution_id=self.workflow_execution.id, - error_message=self.workflow_execution.error_message, - ) - - NotificationHelper.send_notification( - notifications=self.notifications, payload=payload_dto.to_dict() - ) diff --git a/backend/api/postman_collection/__init__.py b/backend/api/postman_collection/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/api/postman_collection/constants.py b/backend/api/postman_collection/constants.py deleted file mode 100644 index 2374c478b..000000000 --- a/backend/api/postman_collection/constants.py +++ /dev/null @@ -1,7 +0,0 @@ -class CollectionKey: - POSTMAN_COLLECTION_V210 = "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" # noqa: E501 - EXECUTE_API_KEY = "Process document" - EXECUTE_PIPELINE_API_KEY = "Process pipeline" - STATUS_API_KEY = "Execution status" - STATUS_EXEC_ID_DEFAULT = "REPLACE_WITH_EXECUTION_ID" - AUTH_QUERY_PARAM_DEFAULT = "REPLACE_WITH_API_KEY" diff --git a/backend/api/postman_collection/dto.py b/backend/api/postman_collection/dto.py deleted file mode 100644 index a10a440d5..000000000 --- a/backend/api/postman_collection/dto.py +++ /dev/null @@ -1,239 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, field -from typing import Any, Optional, Union -from urllib.parse import urlencode, urljoin - -from api.constants import ApiExecution -from api.models import APIDeployment -from api.postman_collection.constants import CollectionKey -from django.conf import settings -from pipeline.models import Pipeline -from utils.request import HTTPMethod - - -@dataclass -class HeaderItem: - key: str - value: str - - -@dataclass -class FormDataItem: - key: str - type: str - src: Optional[str] = None - value: Optional[str] = None - - def __post_init__(self) -> None: - if self.type == "file": - if self.src is None: - raise ValueError("src must be provided for type 'file'") - elif self.type == "text": - if self.value is None: - raise ValueError("value must be provided for type 'text'") - else: - raise ValueError(f"Unsupported type for form data: {self.type}") - - -@dataclass -class BodyItem: - formdata: list[FormDataItem] - mode: str = "formdata" - - -@dataclass -class RequestItem: - method: HTTPMethod - url: str - header: list[HeaderItem] - body: Optional[BodyItem] = None - - -@dataclass -class PostmanItem: - name: str - request: RequestItem - - -@dataclass -class PostmanInfo: - name: str = "Unstract's API deployment" - schema: str = CollectionKey.POSTMAN_COLLECTION_V210 - description: str = "Contains APIs meant for using the deployed Unstract API" - - -class APIBase(ABC): - - # @abstractmethod - # def get_name(self) -> str: - # pass - - # @abstractmethod - # def get_description(self) -> str: - # pass - - @abstractmethod - def get_form_data_items(self) -> list[FormDataItem]: - pass - - @abstractmethod - def get_api_endpoint(self) -> str: - pass - - @abstractmethod - def get_postman_items(self) -> list[PostmanItem]: - pass - - @abstractmethod - def get_api_key(self) -> str: - pass - - def get_execute_body(self) -> BodyItem: - form_data_items = self.get_form_data_items() - return BodyItem(formdata=form_data_items) - - def get_create_api_request(self) -> RequestItem: - header_list = [ - HeaderItem(key="Authorization", value=f"Bearer {self.get_api_key()}") - ] - abs_api_endpoint = urljoin(settings.WEB_APP_ORIGIN_URL, self.get_api_endpoint()) - return RequestItem( - method=HTTPMethod.POST, - header=header_list, - body=self.get_execute_body(), - url=abs_api_endpoint, - ) - - -@dataclass -class APIDeploymentDto(APIBase): - display_name: str - description: str - api_endpoint: str - api_key: str - - def get_postman_info(self) -> PostmanInfo: - return PostmanInfo(name=self.display_name, description=self.description) - - def get_form_data_items(self) -> list[FormDataItem]: - return [ - FormDataItem( - key=ApiExecution.FILES_FORM_DATA, type="file", src="/path_to_file" - ), - FormDataItem( - key=ApiExecution.TIMEOUT_FORM_DATA, - type="text", - value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, - ), - FormDataItem(key=ApiExecution.INCLUDE_METADATA, type="text", value="False"), - ] - - def get_api_key(self) -> str: - return self.api_key - - def get_api_endpoint(self) -> str: - return self.api_endpoint - - def _get_status_api_request(self) -> RequestItem: - header_list = [HeaderItem(key="Authorization", value=f"Bearer {self.api_key}")] - status_query_param = { - "execution_id": CollectionKey.STATUS_EXEC_ID_DEFAULT, - ApiExecution.INCLUDE_METADATA: "False", - } - status_query_str = urlencode(status_query_param) - abs_api_endpoint = urljoin(settings.WEB_APP_ORIGIN_URL, self.api_endpoint) - status_url = urljoin(abs_api_endpoint, "?" + status_query_str) - return RequestItem(method=HTTPMethod.GET, header=header_list, url=status_url) - - def get_postman_items(self) -> list[PostmanItem]: - postman_item_list = [ - PostmanItem( - name=CollectionKey.EXECUTE_API_KEY, - request=self.get_create_api_request(), - ), - PostmanItem( - name=CollectionKey.STATUS_API_KEY, - request=self._get_status_api_request(), - ), - ] - return postman_item_list - - -@dataclass -class PipelineDto(APIBase): - pipeline_name: str - api_endpoint: str - api_key: str - - def get_postman_info(self) -> PostmanInfo: - return PostmanInfo(name=self.pipeline_name, description="") - - def get_form_data_items(self) -> list[FormDataItem]: - return [] - - def get_api_endpoint(self) -> str: - return self.api_endpoint - - def get_api_key(self) -> str: - return self.api_key - - def get_postman_items(self) -> list[PostmanItem]: - postman_item_list = [ - PostmanItem( - name=CollectionKey.EXECUTE_PIPELINE_API_KEY, - request=self.get_create_api_request(), - ) - ] - return postman_item_list - - -@dataclass -class PostmanCollection: - info: PostmanInfo - item: list[PostmanItem] = field(default_factory=list) - - @classmethod - def create( - cls, - instance: Union[APIDeployment, Pipeline], - api_key: str = CollectionKey.AUTH_QUERY_PARAM_DEFAULT, - ) -> "PostmanCollection": - """Creates a PostmanCollection instance. - - This instance can help represent Postman collections (v2 format) that - can be used to easily invoke workflows deployed as APIs - - Args: - instance (APIDeployment): API deployment to generate collection for - api_key (str, optional): Active API key used to authenticate requests for - deployed APIs. Defaults to CollectionKey.AUTH_QUERY_PARAM_DEFAULT. - - Returns: - PostmanCollection: Instance representing PostmanCollection - """ - data_object: APIBase - if isinstance(instance, APIDeployment): - data_object = APIDeploymentDto( - display_name=instance.display_name, - description=instance.description, - api_endpoint=instance.api_endpoint, - api_key=api_key, - ) - elif isinstance(instance, Pipeline): - data_object = PipelineDto( - pipeline_name=instance.pipeline_name, - api_endpoint=instance.api_endpoint, - api_key=api_key, - ) - postman_info: PostmanInfo = data_object.get_postman_info() - postman_item_list = data_object.get_postman_items() - return cls(info=postman_info, item=postman_item_list) - - def to_dict(self) -> dict[str, Any]: - """Convert PostmanCollection instance to a dict. - - Returns: - dict[str, Any]: PostmanCollection as a dict - """ - collection_dict = asdict(self) - return collection_dict diff --git a/backend/api/serializers.py b/backend/api/serializers.py deleted file mode 100644 index 35216e85f..000000000 --- a/backend/api/serializers.py +++ /dev/null @@ -1,148 +0,0 @@ -from collections import OrderedDict -from typing import Any, Union - -from api.constants import ApiExecution -from api.models import APIDeployment, APIKey -from django.core.validators import RegexValidator -from pipeline.models import Pipeline -from rest_framework.serializers import ( - BooleanField, - CharField, - IntegerField, - JSONField, - ModelSerializer, - Serializer, - ValidationError, -) - -from backend.serializers import AuditSerializer - - -class APIDeploymentSerializer(AuditSerializer): - class Meta: - model = APIDeployment - fields = "__all__" - - def validate_api_name(self, value: str) -> str: - api_name_validator = RegexValidator( - regex=r"^[a-zA-Z0-9_-]+$", - message="Only letters, numbers, hyphen and \ - underscores are allowed.", - code="invalid_api_name", - ) - api_name_validator(value) - return value - - -class APIKeySerializer(AuditSerializer): - class Meta: - model = APIKey - fields = "__all__" - - def validate(self, data): - api = data.get("api") - pipeline = data.get("pipeline") - - if api and pipeline: - raise ValidationError( - "Only one of `api` or `pipeline` should be set, not both." - ) - elif not api and not pipeline: - raise ValidationError("At least one of `api` or `pipeline` must be set.") - - return data - - def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]: - """Override the to_representation method to include additional - context.""" - deployment: Union[APIDeployment, Pipeline] = self.context.get("deployment") - representation: OrderedDict[str, Any] = super().to_representation(instance) - - if deployment: - # Handle APIDeployment and Pipeline separately - if isinstance(deployment, APIDeployment): - representation["api"] = deployment.id - representation["pipeline"] = None - representation["description"] = f"API Key for {deployment.api_name}" - elif isinstance(deployment, Pipeline): - representation["api"] = None - representation["pipeline"] = deployment.id - representation["description"] = ( - f"API Key for {deployment.pipeline_name}" - ) - else: - raise ValueError( - "Context must be an instance of APIDeployment or Pipeline" - ) - - representation["is_active"] = True - - return representation - - -class ExecutionRequestSerializer(Serializer): - """Execution request serializer. - - Attributes: - timeout (int): Timeout for the API deployment, maximum value can be 300s. - If -1 it corresponds to async execution. Defaults to -1 - include_metadata (bool): Flag to include metadata in API response - use_file_history (bool): Flag to use FileHistory to save and retrieve - responses quickly. This is undocumented to the user and can be - helpful for demos. - """ - - timeout = IntegerField( - min_value=-1, max_value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, default=-1 - ) - include_metadata = BooleanField(default=False) - use_file_history = BooleanField(default=False) - - -class APIDeploymentListSerializer(ModelSerializer): - workflow_name = CharField(source="workflow.workflow_name", read_only=True) - - class Meta: - model = APIDeployment - fields = [ - "id", - "workflow", - "workflow_name", - "display_name", - "description", - "is_active", - "api_endpoint", - "api_name", - "created_by", - ] - - -class APIKeyListSerializer(ModelSerializer): - class Meta: - model = APIKey - fields = [ - "id", - "created_at", - "modified_at", - "api_key", - "is_active", - "description", - "api", - ] - - -class DeploymentResponseSerializer(Serializer): - is_active = CharField() - id = CharField() - api_key = CharField() - api_endpoint = CharField() - display_name = CharField() - description = CharField() - api_name = CharField() - - -class APIExecutionResponseSerializer(Serializer): - execution_status = CharField() - status_api = CharField() - error = CharField() - result = JSONField() diff --git a/backend/api/tests.py b/backend/api/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/api/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/api/urls.py b/backend/api/urls.py deleted file mode 100644 index 689bdd63b..000000000 --- a/backend/api/urls.py +++ /dev/null @@ -1,66 +0,0 @@ -from api.api_deployment_views import APIDeploymentViewSet, DeploymentExecution -from api.api_key_views import APIKeyViewSet -from django.urls import path, re_path -from rest_framework.urlpatterns import format_suffix_patterns - -deployment = APIDeploymentViewSet.as_view( - { - "get": APIDeploymentViewSet.list.__name__, - "post": APIDeploymentViewSet.create.__name__, - } -) -deployment_details = APIDeploymentViewSet.as_view( - { - "get": APIDeploymentViewSet.retrieve.__name__, - "put": APIDeploymentViewSet.update.__name__, - "patch": APIDeploymentViewSet.partial_update.__name__, - "delete": APIDeploymentViewSet.destroy.__name__, - } -) -download_postman_collection = APIDeploymentViewSet.as_view( - { - "get": APIDeploymentViewSet.download_postman_collection.__name__, - } -) - -execute = DeploymentExecution.as_view() - -key_details = APIKeyViewSet.as_view( - { - "get": APIKeyViewSet.retrieve.__name__, - "put": APIKeyViewSet.update.__name__, - "delete": APIKeyViewSet.destroy.__name__, - } -) -api_key = APIKeyViewSet.as_view( - { - "get": APIKeyViewSet.api_keys.__name__, - "post": APIKeyViewSet.create.__name__, - } -) - -urlpatterns = format_suffix_patterns( - [ - path("deployment/", deployment, name="api_deployment"), - path( - "deployment//", - deployment_details, - name="api_deployment_details", - ), - path( - "postman_collection//", - download_postman_collection, - name="download_postman_collection", - ), - re_path( - r"^api/(?P[\w-]+)/(?P[\w-]+)/?$", - execute, - name="api_deployment_execution", - ), - path("keys//", key_details, name="key_details"), - path("keys/api//", api_key, name="api_key_api"), - path("keys/api/", api_key, name="api_keys_api"), - path("keys/pipeline//", api_key, name="api_key_pipeline"), - path("keys/pipeline/", api_key, name="api_keys_pipeline"), - ] -) diff --git a/backend/api/utils.py b/backend/api/utils.py deleted file mode 100644 index 727fe480d..000000000 --- a/backend/api/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -from api.models import APIDeployment -from api.notification import APINotification -from workflow_manager.workflow.models.execution import WorkflowExecution - - -class APIDeploymentUtils: - @staticmethod - def get_api_by_id(api_id: str) -> Optional[APIDeployment]: - """Retrieves an APIDeployment instance by its unique ID. - - Args: - api_id (str): The unique identifier of the APIDeployment to retrieve. - - Returns: - Optional[APIDeployment]: The APIDeployment instance if found, - otherwise None. - """ - try: - api_deployment: APIDeployment = APIDeployment.objects.get(pk=api_id) - return api_deployment - except APIDeployment.DoesNotExist: - return None - - @staticmethod - def send_notification( - api: APIDeployment, workflow_execution: WorkflowExecution - ) -> None: - """Sends a notification for the specified API deployment and workflow - execution. - - Args: - api (APIDeployment): The APIDeployment instance for which the - notification is being sent. - workflow_execution (WorkflowExecution): The WorkflowExecution instance - related to the notification. - - Returns: - None - """ - api_notification = APINotification( - api=api, workflow_execution=workflow_execution - ) - api_notification.send() diff --git a/backend/connector/__init__.py b/backend/connector/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/connector/admin.py b/backend/connector/admin.py deleted file mode 100644 index 3750fb4c5..000000000 --- a/backend/connector/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import ConnectorInstance - -admin.site.register(ConnectorInstance) diff --git a/backend/connector/apps.py b/backend/connector/apps.py deleted file mode 100644 index b9908c2e6..000000000 --- a/backend/connector/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class ConnectorConfig(AppConfig): - name = "connector" diff --git a/backend/connector/connector_instance_helper.py b/backend/connector/connector_instance_helper.py deleted file mode 100644 index 3af7eff41..000000000 --- a/backend/connector/connector_instance_helper.py +++ /dev/null @@ -1,330 +0,0 @@ -import logging -from typing import Any, Optional - -from account.models import User -from connector.constants import ConnectorInstanceConstant -from connector.models import ConnectorInstance -from connector.unstract_account import UnstractAccount -from django.conf import settings -from django.db import connection -from workflow_manager.workflow.models.workflow import Workflow - -from unstract.connectors.filesystems.ucs import UnstractCloudStorage -from unstract.connectors.filesystems.ucs.constants import UCSKey - -logger = logging.getLogger(__name__) - - -class ConnectorInstanceHelper: - @staticmethod - def create_default_gcs_connector(workflow: Workflow, user: User) -> None: - """Method to create default storage connector. - - Args: - org_id (str) - workflow (Workflow) - user (User) - """ - org_schema = connection.tenant.schema_name - if not user.project_storage_created: - logger.info("Creating default storage") - account = UnstractAccount(org_schema, user.email) - account.provision_s3_storage() - account.upload_sample_files() - user.project_storage_created = True - user.save() - logger.info("default storage created successfully.") - - logger.info("Adding connectors to Unstract") - connector_name = ConnectorInstanceConstant.USER_STORAGE - gcs_id = UnstractCloudStorage.get_id() - bucket_name = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME - base_path = f"{bucket_name}/{org_schema}/{user.email}" - - connector_metadata = { - UCSKey.KEY: settings.GOOGLE_STORAGE_ACCESS_KEY_ID, - UCSKey.SECRET: settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY, - UCSKey.ENDPOINT_URL: settings.GOOGLE_STORAGE_BASE_URL, - } - connector_metadata__input = { - **connector_metadata, - UCSKey.PATH: base_path + "/input", - } - connector_metadata__output = { - **connector_metadata, - UCSKey.PATH: base_path + "/output", - } - ConnectorInstance.objects.create( - connector_name=connector_name, - workflow=workflow, - created_by=user, - connector_id=gcs_id, - connector_metadata=connector_metadata__input, - connector_type=ConnectorInstance.ConnectorType.INPUT, - connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM, - ) - ConnectorInstance.objects.create( - connector_name=connector_name, - workflow=workflow, - created_by=user, - connector_id=gcs_id, - connector_metadata=connector_metadata__output, - connector_type=ConnectorInstance.ConnectorType.OUTPUT, - connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM, - ) - logger.info("Connectors added successfully.") - - @staticmethod - def get_connector_instances_by_workflow( - workflow_id: str, - connector_type: tuple[str, str], - connector_mode: Optional[tuple[int, str]] = None, - values: Optional[list[str]] = None, - connector_name: Optional[str] = None, - ) -> list[ConnectorInstance]: - """Method to get connector instances by workflow. - - Args: - workflow_id (str) - connector_type (tuple[str, str]): Specifies input/output - connector_mode (Optional[tuple[int, str]], optional): - Specifies database/file - values (Optional[list[str]], optional): Defaults to None. - connector_name (Optional[str], optional): Defaults to None. - - Returns: - list[ConnectorInstance] - """ - logger.info(f"Setting connector mode to {connector_mode}") - filter_params: dict[str, Any] = { - "workflow": workflow_id, - "connector_type": connector_type, - } - if connector_mode is not None: - filter_params["connector_mode"] = connector_mode - if connector_name is not None: - filter_params["connector_name"] = connector_name - - connector_instances = ConnectorInstance.objects.filter(**filter_params).all() - logger.debug(f"Retrieved connector instance values {connector_instances}") - if values is not None: - filtered_connector_instances = connector_instances.values(*values) - logger.info( - f"Returning filtered \ - connector instance value {filtered_connector_instances}" - ) - return list(filtered_connector_instances) - logger.info(f"Returning connector instances {connector_instances}") - return list(connector_instances) - - @staticmethod - def get_connector_instance_by_workflow( - workflow_id: str, - connector_type: tuple[str, str], - connector_mode: Optional[tuple[int, str]] = None, - connector_name: Optional[str] = None, - ) -> Optional[ConnectorInstance]: - """Get one connector instance. - - Use this method if the connector instance is unique for \ - filter_params - Args: - workflow_id (str): _description_ - connector_type (tuple[str, str]): Specifies input/output - connector_mode (Optional[tuple[int, str]], optional). - Specifies database/filesystem - values (Optional[list[str]], optional). - connector_name (Optional[str], optional). - - Returns: - list[ConnectorInstance]: _description_ - """ - logger.info("Fetching connector instance by workflow") - filter_params: dict[str, Any] = { - "workflow": workflow_id, - "connector_type": connector_type, - } - if connector_mode is not None: - filter_params["connector_mode"] = connector_mode - if connector_name is not None: - filter_params["connector_name"] = connector_name - - try: - connector_instance: ConnectorInstance = ConnectorInstance.objects.filter( - **filter_params - ).first() - except Exception as exc: - logger.error(f"Error occured while fetching connector instances {exc}") - raise exc - - return connector_instance - - @staticmethod - def get_input_connector_instance_by_name_for_workflow( - workflow_id: str, - connector_name: str, - ) -> Optional[ConnectorInstance]: - """Method to get Input connector instance name from the workflow. - - Args: - workflow_id (str) - connector_name (str) - - Returns: - Optional[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instance_by_workflow( - workflow_id=workflow_id, - connector_type=ConnectorInstance.ConnectorType.INPUT, - connector_name=connector_name, - ) - - @staticmethod - def get_output_connector_instance_by_name_for_workflow( - workflow_id: str, - connector_name: str, - ) -> Optional[ConnectorInstance]: - """Method to get output connector name by Workflow. - - Args: - workflow_id (str) - connector_name (str) - - Returns: - Optional[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instance_by_workflow( - workflow_id=workflow_id, - connector_type=ConnectorInstance.ConnectorType.OUTPUT, - connector_name=connector_name, - ) - - @staticmethod - def get_input_connector_instances_by_workflow( - workflow_id: str, - ) -> list[ConnectorInstance]: - """Method to get connector instances by workflow. - - Args: - workflow_id (str) - - Returns: - list[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, ConnectorInstance.ConnectorType.INPUT - ) - - @staticmethod - def get_output_connector_instances_by_workflow( - workflow_id: str, - ) -> list[ConnectorInstance]: - """Method to get output connector instances by workflow. - - Args: - workflow_id (str): _description_ - - Returns: - list[ConnectorInstance]: _description_ - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, ConnectorInstance.ConnectorType.OUTPUT - ) - - @staticmethod - def get_file_system_input_connector_instances_by_workflow( - workflow_id: str, values: Optional[list[str]] = None - ) -> list[ConnectorInstance]: - """Method to fetch file system connector by workflow. - - Args: - workflow_id (str): - values (Optional[list[str]], optional) - - Returns: - list[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, - ConnectorInstance.ConnectorType.INPUT, - ConnectorInstance.ConnectorMode.FILE_SYSTEM, - values, - ) - - @staticmethod - def get_file_system_output_connector_instances_by_workflow( - workflow_id: str, values: Optional[list[str]] = None - ) -> list[ConnectorInstance]: - """Method to get file system output connector by workflow. - - Args: - workflow_id (str) - values (Optional[list[str]], optional) - - Returns: - list[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, - ConnectorInstance.ConnectorType.OUTPUT, - ConnectorInstance.ConnectorMode.FILE_SYSTEM, - values, - ) - - @staticmethod - def get_database_input_connector_instances_by_workflow( - workflow_id: str, values: Optional[list[str]] = None - ) -> list[ConnectorInstance]: - """Method to fetch input database connectors by workflow. - - Args: - workflow_id (str) - values (Optional[list[str]], optional) - - Returns: - list[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, - ConnectorInstance.ConnectorType.INPUT, - ConnectorInstance.ConnectorMode.DATABASE, - values, - ) - - @staticmethod - def get_database_output_connector_instances_by_workflow( - workflow_id: str, values: Optional[list[str]] = None - ) -> list[ConnectorInstance]: - """Method to fetch output database connectors by workflow. - - Args: - workflow_id (str) - values (Optional[list[str]], optional) - - Returns: - list[ConnectorInstance] - """ - return ConnectorInstanceHelper.get_connector_instances_by_workflow( - workflow_id, - ConnectorInstance.ConnectorType.OUTPUT, - ConnectorInstance.ConnectorMode.DATABASE, - values, - ) - - @staticmethod - def get_input_output_connector_instances_by_workflow( - workflow_id: str, - ) -> list[ConnectorInstance]: - """Method to fetch input and output connectors by workflow. - - Args: - workflow_id (str) - - Returns: - list[ConnectorInstance] - """ - filter_params: dict[str, Any] = { - "workflow": workflow_id, - } - connector_instances = ConnectorInstance.objects.filter(**filter_params).all() - return list(connector_instances) diff --git a/backend/connector/constants.py b/backend/connector/constants.py deleted file mode 100644 index 5b9234e6e..000000000 --- a/backend/connector/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -class ConnectorInstanceKey: - CONNECTOR_ID = "connector_id" - CONNECTOR_NAME = "connector_name" - CONNECTOR_TYPE = "connector_type" - CONNECTOR_MODE = "connector_mode" - CONNECTOR_VERSION = "connector_version" - CONNECTOR_AUTH = "connector_auth" - CONNECTOR_METADATA = "connector_metadata" - CONNECTOR_METADATA_B = "connector_metadata_b" - CONNECTOR_EXISTS = ( - "Connector with this configuration already exists in this project." - ) - DUPLICATE_API = "It appears that a duplicate call may have been made." - - -class ConnectorInstanceConstant: - USER_STORAGE = "User Storage" diff --git a/backend/connector/fields.py b/backend/connector/fields.py deleted file mode 100644 index b96b206bb..000000000 --- a/backend/connector/fields.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging -from datetime import datetime - -from connector_auth.constants import SocialAuthConstants -from connector_auth.models import ConnectorAuth -from django.db import models - -logger = logging.getLogger(__name__) - - -class ConnectorAuthJSONField(models.JSONField): - def from_db_value(self, value, expression, connection): # type: ignore - """Overrding default function.""" - metadata = super().from_db_value(value, expression, connection) - provider = metadata.get(SocialAuthConstants.PROVIDER) - uid = metadata.get(SocialAuthConstants.UID) - if provider and uid: - refresh_after_str = metadata.get(SocialAuthConstants.REFRESH_AFTER) - if refresh_after_str: - refresh_after = datetime.strptime( - refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT - ) - if datetime.now() > refresh_after: - metadata = self._refresh_tokens(provider, uid) - return metadata - - def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]: - """Retrieves PSA object and refreshes the token if necessary.""" - connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth( - provider=provider, uid=uid - ) - tokens_refreshed = False - if connector_auth: - ( - connector_metadata, - tokens_refreshed, - ) = connector_auth.get_and_refresh_tokens() - return connector_metadata # type: ignore diff --git a/backend/connector/migrations/0001_initial.py b/backend/connector/migrations/0001_initial.py deleted file mode 100644 index 08d18dce2..000000000 --- a/backend/connector/migrations/0001_initial.py +++ /dev/null @@ -1,122 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import connector.fields -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - ("project", "0001_initial"), - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("workflow", "0001_initial"), - ("connector_auth", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ConnectorInstance", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ("connector_name", models.TextField(max_length=128)), - ("connector_id", models.CharField(default="", max_length=128)), - ( - "connector_metadata", - connector.fields.ConnectorAuthJSONField( - db_column="connector_metadata", default=dict - ), - ), - ( - "connector_version", - models.CharField(default="", max_length=64), - ), - ( - "connector_type", - models.CharField( - choices=[("INPUT", "Input"), ("OUTPUT", "Output")] - ), - ), - ( - "connector_mode", - models.CharField( - choices=[ - (0, "UNKNOWN"), - (1, "FILE_SYSTEM"), - (2, "DATABASE"), - ], - db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE", - default=0, - ), - ), - ( - "connector_auth", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to="connector_auth.connectorauth", - ), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_connectors", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_connectors", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "project", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="project_connectors", - to="project.project", - ), - ), - ( - "workflow", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="workflow_connectors", - to="workflow.workflow", - ), - ), - ], - ), - migrations.AddConstraint( - model_name="connectorinstance", - constraint=models.UniqueConstraint( - fields=("connector_name", "workflow", "connector_type"), - name="unique_connector", - ), - ), - ] diff --git a/backend/connector/migrations/0002_connectorinstance_connector_metadata_b.py b/backend/connector/migrations/0002_connectorinstance_connector_metadata_b.py deleted file mode 100644 index 7c075439d..000000000 --- a/backend/connector/migrations/0002_connectorinstance_connector_metadata_b.py +++ /dev/null @@ -1,39 +0,0 @@ -# Generated by Django 4.2.1 on 2024-02-16 06:50 - -import json -from typing import Any - -from connector.models import ConnectorInstance -from cryptography.fernet import Fernet -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("connector", "0001_initial"), - ] - - def EncryptCredentials(apps: Any, schema_editor: Any) -> None: - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - queryset = ConnectorInstance.objects.all() - - for obj in queryset: # type: ignore - # Access attributes of the object - - if hasattr(obj, "connector_metadata"): - json_string: str = json.dumps(obj.connector_metadata) - obj.connector_metadata_b = f.encrypt(json_string.encode("utf-8")) - obj.save() - - operations = [ - migrations.AddField( - model_name="connectorinstance", - name="connector_metadata_b", - field=models.BinaryField(null=True), - ), - migrations.RunPython( - EncryptCredentials, reverse_code=migrations.RunPython.noop - ), - ] diff --git a/backend/connector/migrations/0003_alter_connectorinstance_connector_mode.py b/backend/connector/migrations/0003_alter_connectorinstance_connector_mode.py deleted file mode 100644 index 5e5f1d50e..000000000 --- a/backend/connector/migrations/0003_alter_connectorinstance_connector_mode.py +++ /dev/null @@ -1,27 +0,0 @@ -# Generated by Django 4.2.1 on 2024-06-24 12:51 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("connector", "0002_connectorinstance_connector_metadata_b"), - ] - - operations = [ - migrations.AlterField( - model_name="connectorinstance", - name="connector_mode", - field=models.CharField( - choices=[ - (0, "UNKNOWN"), - (1, "FILE_SYSTEM"), - (2, "DATABASE"), - (3, "APPDEPLOYMENT"), - ], - db_comment="Choices of connectors", - default=0, - ), - ), - ] diff --git a/backend/connector/migrations/0004_alter_connectorinstance_connector_mode.py b/backend/connector/migrations/0004_alter_connectorinstance_connector_mode.py deleted file mode 100644 index 6feec997f..000000000 --- a/backend/connector/migrations/0004_alter_connectorinstance_connector_mode.py +++ /dev/null @@ -1,28 +0,0 @@ -# Generated by Django 4.2.1 on 2024-07-04 05:44 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("connector", "0003_alter_connectorinstance_connector_mode"), - ] - - operations = [ - migrations.AlterField( - model_name="connectorinstance", - name="connector_mode", - field=models.CharField( - choices=[ - (0, "UNKNOWN"), - (1, "FILE_SYSTEM"), - (2, "DATABASE"), - (3, "APPDEPLOYMENT"), - (4, "MANUAL_REVIEW"), - ], - db_comment="Choices of connectors", - default=0, - ), - ), - ] diff --git a/backend/connector/migrations/__init__.py b/backend/connector/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/connector/models.py b/backend/connector/models.py deleted file mode 100644 index 958946efe..000000000 --- a/backend/connector/models.py +++ /dev/null @@ -1,131 +0,0 @@ -import json -import uuid -from typing import Any - -from account.models import User -from connector.fields import ConnectorAuthJSONField -from connector_auth.models import ConnectorAuth -from connector_processor.connector_processor import ConnectorProcessor -from connector_processor.constants import ConnectorKeys -from cryptography.fernet import Fernet, InvalidToken -from django.conf import settings -from django.db import models -from project.models import Project -from utils.exceptions import InvalidEncryptionKey -from utils.models.base_model import BaseModel -from workflow_manager.workflow.models import Workflow - -from backend.constants import FieldLengthConstants as FLC - -CONNECTOR_NAME_SIZE = 128 -VERSION_NAME_SIZE = 64 - - -class ConnectorInstance(BaseModel): - class ConnectorType(models.TextChoices): - INPUT = "INPUT", "Input" - OUTPUT = "OUTPUT", "Output" - - class ConnectorMode(models.IntegerChoices): - UNKNOWN = 0, "UNKNOWN" - FILE_SYSTEM = 1, "FILE_SYSTEM" - DATABASE = 2, "DATABASE" - APPDEPLOYMENT = 3, "APPDEPLOYMENT" - MANUAL_REVIEW = 4, "MANUAL_REVIEW" - - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - connector_name = models.TextField( - max_length=CONNECTOR_NAME_SIZE, null=False, blank=False - ) - project = models.ForeignKey( - Project, - on_delete=models.CASCADE, - related_name="project_connectors", - null=True, - blank=True, - ) - workflow = models.ForeignKey( - Workflow, - on_delete=models.CASCADE, - related_name="workflow_connectors", - null=False, - blank=False, - ) - connector_id = models.CharField(max_length=FLC.CONNECTOR_ID_LENGTH, default="") - # TODO Required to be removed - connector_metadata = ConnectorAuthJSONField( - db_column="connector_metadata", null=False, blank=False, default=dict - ) - connector_metadata_b = models.BinaryField(null=True) - connector_version = models.CharField(max_length=VERSION_NAME_SIZE, default="") - connector_type = models.CharField(choices=ConnectorType.choices) - connector_auth = models.ForeignKey( - ConnectorAuth, on_delete=models.SET_NULL, null=True, blank=True - ) - connector_mode = models.CharField( - choices=ConnectorMode.choices, - default=ConnectorMode.UNKNOWN, - db_comment="Choices of connectors", - ) - - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="created_connectors", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="modified_connectors", - null=True, - blank=True, - ) - - def get_connector_metadata(self) -> dict[str, str]: - """Gets connector metadata and refreshes the tokens if needed in case - of OAuth.""" - tokens_refreshed = False - if self.connector_auth: - ( - self.connector_metadata, - tokens_refreshed, - ) = self.connector_auth.get_and_refresh_tokens() - if tokens_refreshed: - self.save() - return self.connector_metadata - - @staticmethod - def supportsOAuth(connector_id: str) -> bool: - return bool( - ConnectorProcessor.get_connector_data_with_key( - connector_id, ConnectorKeys.OAUTH - ) - ) - - def __str__(self) -> str: - return ( - f"Connector({self.id}, type{self.connector_type}," - f" workflow: {self.workflow})" - ) - - @property - def metadata(self) -> Any: - try: - encryption_secret: str = settings.ENCRYPTION_KEY - cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8")) - decrypted_value = cipher_suite.decrypt( - bytes(self.connector_metadata_b).decode("utf-8") - ) - except InvalidToken: - raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.CONNECTOR) - return json.loads(decrypted_value) - - class Meta: - constraints = [ - models.UniqueConstraint( - fields=["connector_name", "workflow", "connector_type"], - name="unique_connector", - ), - ] diff --git a/backend/connector/serializers.py b/backend/connector/serializers.py deleted file mode 100644 index c5ce1054f..000000000 --- a/backend/connector/serializers.py +++ /dev/null @@ -1,82 +0,0 @@ -import json -import logging -from collections import OrderedDict -from typing import Any, Optional - -from connector.constants import ConnectorInstanceKey as CIKey -from connector_auth.models import ConnectorAuth -from connector_auth.pipeline.common import ConnectorAuthHelper -from connector_processor.connector_processor import ConnectorProcessor -from connector_processor.constants import ConnectorKeys -from connector_processor.exceptions import OAuthTimeOut -from cryptography.fernet import Fernet -from django.conf import settings -from utils.serializer_utils import SerializerUtils - -from backend.serializers import AuditSerializer -from unstract.connectors.filesystems.ucs import UnstractCloudStorage - -from .models import ConnectorInstance - -logger = logging.getLogger(__name__) - - -class ConnectorInstanceSerializer(AuditSerializer): - class Meta: - model = ConnectorInstance - fields = "__all__" - - def save(self, **kwargs): # type: ignore - user = self.context.get("request").user or None - connector_id: str = kwargs[CIKey.CONNECTOR_ID] - connector_oauth: Optional[ConnectorAuth] = None - if ( - ConnectorInstance.supportsOAuth(connector_id=connector_id) - and CIKey.CONNECTOR_METADATA in kwargs - ): - try: - connector_oauth = ConnectorAuthHelper.get_or_create_connector_auth( - user=user, # type: ignore - oauth_credentials=kwargs[CIKey.CONNECTOR_METADATA], - ) - kwargs[CIKey.CONNECTOR_AUTH] = connector_oauth - ( - kwargs[CIKey.CONNECTOR_METADATA], - refresh_status, - ) = connector_oauth.get_and_refresh_tokens() - except Exception as exc: - logger.error(f"Error while obtaining ConnectorAuth: {exc}") - raise OAuthTimeOut - - connector_mode = ConnectorProcessor.get_connector_data_with_key( - connector_id, CIKey.CONNECTOR_MODE - ) - kwargs[CIKey.CONNECTOR_MODE] = connector_mode.value - - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) - json_string: str = json.dumps(kwargs.pop(CIKey.CONNECTOR_METADATA)) - if self.validated_data: - self.validated_data.pop(CIKey.CONNECTOR_METADATA) - - kwargs[CIKey.CONNECTOR_METADATA_B] = f.encrypt(json_string.encode("utf-8")) - - instance = super().save(**kwargs) - return instance - - def to_representation(self, instance: ConnectorInstance) -> dict[str, str]: - # to remove the sensitive fields being returned - rep: OrderedDict[str, Any] = super().to_representation(instance) - if instance.connector_id == UnstractCloudStorage.get_id(): - rep[CIKey.CONNECTOR_METADATA] = {} - if SerializerUtils.check_context_for_GET_or_POST(context=self.context): - rep.pop(CIKey.CONNECTOR_AUTH) - # set icon fields for UI - rep[ConnectorKeys.ICON] = ConnectorProcessor.get_connector_data_with_key( - instance.connector_id, ConnectorKeys.ICON - ) - - rep.pop(CIKey.CONNECTOR_METADATA_B) - if instance.connector_metadata_b: - rep[CIKey.CONNECTOR_METADATA] = instance.metadata - return rep diff --git a/backend/connector/tests/conftest.py b/backend/connector/tests/conftest.py deleted file mode 100644 index 89f1715a1..000000000 --- a/backend/connector/tests/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytest -from django.core.management import call_command - - -@pytest.fixture(scope="session") -def django_db_setup(django_db_blocker): # type: ignore - fixtures = ["./connector/tests/fixtures/fixtures_0001.json"] - with django_db_blocker.unblock(): - call_command("loaddata", *fixtures) diff --git a/backend/connector/tests/connector_tests.py b/backend/connector/tests/connector_tests.py deleted file mode 100644 index cad6512f2..000000000 --- a/backend/connector/tests/connector_tests.py +++ /dev/null @@ -1,332 +0,0 @@ -# mypy: ignore-errors -import pytest -from connector.models import ConnectorInstance -from django.urls import reverse -from rest_framework import status -from rest_framework.test import APITestCase - -pytestmark = pytest.mark.django_db - - -@pytest.mark.connector -class TestConnector(APITestCase): - def test_connector_list(self) -> None: - """Tests to List the connectors.""" - - url = reverse("connectors_v1-list") - response = self.client.get(url) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_connectors_detail(self) -> None: - """Tests to fetch a connector with given pk.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - response = self.client.get(url) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_connectors_detail_not_found(self) -> None: - """Tests for negative case to fetch non exiting key.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 768}) - response = self.client.get(url) - - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_connectors_create(self) -> None: - """Tests to create a new ConnectorInstance.""" - - url = reverse("connectors_v1-list") - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "sample_url", - "sharable_link": True, - }, - } - response = self.client.post(url, data, format="json") - - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(ConnectorInstance.objects.count(), 2) - - def test_connectors_create_with_json_list(self) -> None: - """Tests to create a new connector with list included in the json - field.""" - - url = reverse("connectors_v1-list") - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "sample_url", - "sharable_link": True, - "file_name_list": ["a1", "a2"], - }, - } - response = self.client.post(url, data, format="json") - - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(ConnectorInstance.objects.count(), 2) - - def test_connectors_create_with_nested_json(self) -> None: - """Tests to create a new connector with json field as nested json.""" - - url = reverse("connectors_v1-list") - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.post(url, data, format="json") - - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(ConnectorInstance.objects.count(), 2) - - def test_connectors_create_bad_request(self) -> None: - """Tests for negative case to throw error on a wrong access.""" - - url = reverse("connectors_v1-list") - data = { - "org": 5, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.post(url, data, format="json") - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_connectors_update_json_field(self) -> None: - """Tests to update connector with json field update.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "new_sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.put(url, data, format="json") - drive_link = response.data["connector_metadata"]["drive_link"] - self.assertEqual(drive_link, "new_sample_url") - - def test_connectors_update(self) -> None: - """Tests to update connector update single field.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "org": 1, - "project": 1, - "created_by": 1, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "new_sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.put(url, data, format="json") - modified_by = response.data["modified_by"] - self.assertEqual(modified_by, 2) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_connectors_update_pk(self) -> None: - """Tests the PUT method for 400 error.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "org": 2, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "new_sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.put(url, data, format="json") - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_connectors_update_json_fields(self) -> None: - """Tests to update ConnectorInstance.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "new_sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - }, - } - response = self.client.put(url, data, format="json") - nested_value = response.data["connector_metadata"]["sample_metadata_json"][ - "key1" - ] - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(nested_value, "value1") - - def test_connectors_update_json_list_fields(self) -> None: - """Tests to update connector to the third second level of json.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - "connector_metadata": { - "drive_link": "new_sample_url", - "sharable_link": True, - "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - "file_list": ["a1", "a2", "a3"], - }, - } - response = self.client.put(url, data, format="json") - nested_value = response.data["connector_metadata"]["sample_metadata_json"][ - "key1" - ] - nested_list = response.data["connector_metadata"]["file_list"] - last_val = nested_list.pop() - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(nested_value, "value1") - self.assertEqual(last_val, "a3") - - # @pytest.mark.xfail(raises=KeyError) - # def test_connectors_update_json_fields_failed(self) -> None: - # """Tests to update connector to the second level of JSON with a wrong - # key.""" - - # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - # data = { - # "org": 1, - # "project": 1, - # "created_by": 2, - # "modified_by": 2, - # "modified_at": "2023-06-14T05:28:47.759Z", - # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - # "connector_metadata": { - # "drive_link": "new_sample_url", - # "sharable_link": True, - # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - # }, - # } - # response = self.client.put(url, data, format="json") - # nested_value = response.data["connector_metadata"]["sample_metadata_json"][ - # "key00" - # ] - - # @pytest.mark.xfail(raises=KeyError) - # def test_connectors_update_json_nested_failed(self) -> None: - # """Tests to update connector to test a first level of json with a wrong - # key.""" - - # url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - # data = { - # "org": 1, - # "project": 1, - # "created_by": 2, - # "modified_by": 2, - # "modified_at": "2023-06-14T05:28:47.759Z", - # "connector_id": "e3a4512m-efgb-48d5-98a9-3983nd77f", - # "connector_metadata": { - # "drive_link": "new_sample_url", - # "sharable_link": True, - # "sample_metadata_json": {"key1": "value1", "key2": "value2"}, - # }, - # } - # response = self.client.put(url, data, format="json") - # nested_value = response.data["connector_metadata"]["sample_metadata_jsonNew"] - - def test_connectors_update_field(self) -> None: - """Tests the PATCH method.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = {"connector_id": "e3a4512m-efgb-48d5-98a9-3983ntest"} - response = self.client.patch(url, data, format="json") - self.assertEqual(response.status_code, status.HTTP_200_OK) - connector_id = response.data["connector_id"] - - self.assertEqual( - connector_id, - ConnectorInstance.objects.get(connector_id=connector_id).connector_id, - ) - - def test_connectors_update_json_field_patch(self) -> None: - """Tests the PATCH method.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - data = { - "connector_metadata": { - "drive_link": "patch_update_url", - "sharable_link": True, - "sample_metadata_json": { - "key1": "patch_update1", - "key2": "value2", - }, - } - } - - response = self.client.patch(url, data, format="json") - self.assertEqual(response.status_code, status.HTTP_200_OK) - drive_link = response.data["connector_metadata"]["drive_link"] - - self.assertEqual(drive_link, "patch_update_url") - - def test_connectors_delete(self) -> None: - """Tests the DELETE method.""" - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - response = self.client.delete(url, format="json") - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - - url = reverse("connectors_v1-detail", kwargs={"pk": 1}) - response = self.client.get(url) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) diff --git a/backend/connector/tests/fixtures/fixtures_0001.json b/backend/connector/tests/fixtures/fixtures_0001.json deleted file mode 100644 index 55b39e6d8..000000000 --- a/backend/connector/tests/fixtures/fixtures_0001.json +++ /dev/null @@ -1,67 +0,0 @@ -[ - { - "model": "account.org", - "pk": 1, - "fields": { - "org_name": "Zipstack", - "created_by": 1, - "modified_by": 1, - "modified_at": "2023-06-14T05:28:47.739Z" - } - }, - { - "model": "account.user", - "pk": 1, - "fields": { - "org": 1, - "email": "johndoe@gmail.com", - "first_name": "John", - "last_name": "Doe", - "is_admin": true, - "created_by": null, - "modified_by": null, - "modified_at": "2023-06-14T05:28:47.744Z" - } - }, - { - "model": "account.user", - "pk": 2, - "fields": { - "org": 1, - "email": "user1@gmail.com", - "first_name": "Ron", - "last_name": "Stone", - "is_admin": false, - "created_by": 1, - "modified_by": 1, - "modified_at": "2023-06-14T05:28:47.750Z" - } - }, - { - "model": "project.project", - "pk": 1, - "fields": { - "org": 1, - "project_name": "Unstract Test", - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z" - } - }, - { - "model": "connector.connector", - "pk": 1, - "fields": { - "org": 1, - "project": 1, - "created_by": 2, - "modified_by": 2, - "modified_at": "2023-06-14T05:28:47.759Z", - "connector_id": "e38a59b7-efbb-48d5-9da6-3a0cf2d882a0", - "connector_metadata": { - "connector_type": "gdrive", - "auth_type": "oauth" - } - } - } -] diff --git a/backend/connector/unstract_account.py b/backend/connector/unstract_account.py deleted file mode 100644 index e8f496089..000000000 --- a/backend/connector/unstract_account.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -import os - -import boto3 -from botocore.exceptions import ClientError -from django.conf import settings - -logger = logging.getLogger(__name__) - - -# TODO: UnstractAccount need to be pluggable -class UnstractAccount: - def __init__(self, tenant: str, username: str) -> None: - self.tenant = tenant - self.username = username - - def provision_s3_storage(self) -> None: - access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID - secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY - bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME - - s3 = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - endpoint_url="https://storage.googleapis.com", - ) - - # Check if folder exists and create if it is not available - account_folder = f"{self.tenant}/{self.username}/input/examples/" - try: - logger.info(f"Checking if folder {account_folder} exists...") - s3.head_object(Bucket=bucket_name, Key=account_folder) - logger.info(f"Folder {account_folder} already exists") - except ClientError as e: - logger.info(f"{bucket_name} Folder {account_folder} does not exist") - if e.response["Error"]["Code"] == "404": - logger.info(f"Folder {account_folder} does not exist. Creating it...") - s3.put_object(Bucket=bucket_name, Key=account_folder) - account_folder_output = f"{self.tenant}/{self.username}/output/" - s3.put_object(Bucket=bucket_name, Key=account_folder_output) - else: - logger.error(f"Error checking folder {account_folder}: {e}") - raise e - - def upload_sample_files(self) -> None: - access_key = settings.GOOGLE_STORAGE_ACCESS_KEY_ID - secret_key = settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY - bucket_name: str = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME - - s3 = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - endpoint_url="https://storage.googleapis.com", - ) - - folder = f"{self.tenant}/{self.username}/input/examples/" - - local_path = f"{os.path.dirname(__file__)}/static" - for root, dirs, files in os.walk(local_path): - for file in files: - local_file_path = os.path.join(root, file) - s3_key = os.path.join( - folder, os.path.relpath(local_file_path, local_path) - ) - logger.info( - f"Uploading: {local_file_path} => " f"s3://{bucket_name}/{s3_key}" - ) - try: - s3.upload_file(local_file_path, bucket_name, s3_key) - except ClientError as e: - logger.error(e) - raise e - logger.info(f"Uploaded: {local_file_path}") diff --git a/backend/connector/urls.py b/backend/connector/urls.py deleted file mode 100644 index 424033528..000000000 --- a/backend/connector/urls.py +++ /dev/null @@ -1,21 +0,0 @@ -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns - -from .views import ConnectorInstanceViewSet as CIViewSet - -connector_list = CIViewSet.as_view({"get": "list", "post": "create"}) -connector_detail = CIViewSet.as_view( - { - "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", - } -) - -urlpatterns = format_suffix_patterns( - [ - path("connector/", connector_list, name="connector-list"), - path("connector//", connector_detail, name="connector-detail"), - ] -) diff --git a/backend/connector/views.py b/backend/connector/views.py deleted file mode 100644 index 2443d6bd2..000000000 --- a/backend/connector/views.py +++ /dev/null @@ -1,123 +0,0 @@ -import logging -from typing import Any, Optional - -from account.custom_exceptions import DuplicateData -from connector.constants import ConnectorInstanceKey as CIKey -from connector_auth.constants import ConnectorAuthKey -from connector_auth.exceptions import CacheMissException, MissingParamException -from connector_auth.pipeline.common import ConnectorAuthHelper -from connector_processor.exceptions import OAuthTimeOut -from django.db import IntegrityError -from django.db.models import QuerySet -from rest_framework import status, viewsets -from rest_framework.response import Response -from rest_framework.versioning import URLPathVersioning -from utils.filtering import FilterHelper - -from backend.constants import RequestKey - -from .models import ConnectorInstance -from .serializers import ConnectorInstanceSerializer - -logger = logging.getLogger(__name__) - - -class ConnectorInstanceViewSet(viewsets.ModelViewSet): - versioning_class = URLPathVersioning - queryset = ConnectorInstance.objects.all() - serializer_class = ConnectorInstanceSerializer - - def get_queryset(self) -> Optional[QuerySet]: - filter_args = FilterHelper.build_filter_args( - self.request, - RequestKey.WORKFLOW, - RequestKey.CREATED_BY, - CIKey.CONNECTOR_TYPE, - CIKey.CONNECTOR_MODE, - ) - if filter_args: - queryset = ConnectorInstance.objects.filter(**filter_args) - else: - queryset = ConnectorInstance.objects.all() - return queryset - - def _get_connector_metadata(self, connector_id: str) -> Optional[dict[str, str]]: - """Gets connector metadata for the ConnectorInstance. - - For non oauth based - obtains from request - For oauth based - obtains from cache - - Raises: - e: MissingParamException, CacheMissException - - Returns: - dict[str, str]: Connector creds dict to connect with - """ - connector_metadata = None - if ConnectorInstance.supportsOAuth(connector_id=connector_id): - logger.info(f"Fetching oauth data for {connector_id}") - oauth_key = self.request.query_params.get(ConnectorAuthKey.OAUTH_KEY) - if oauth_key is None: - logger.error("OAuth key missing") - raise MissingParamException(param=ConnectorAuthKey.OAUTH_KEY) - connector_metadata = ConnectorAuthHelper.get_oauth_creds_from_cache( - cache_key=oauth_key, delete_key=True - ) - if connector_metadata is None: - raise CacheMissException( - f"Couldn't find credentials for {oauth_key} from cache" - ) - else: - connector_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) - return connector_metadata - - def perform_update(self, serializer: ConnectorInstanceSerializer) -> None: - connector_metadata = None - connector_id = self.request.data.get( - CIKey.CONNECTOR_ID, serializer.instance.connector_id - ) - try: - connector_metadata = self._get_connector_metadata(connector_id) - except Exception: - # Suppress here to not shout during partial updates - pass - # Take metadata from instance itself since update - # is performed on other fields of ConnectorInstance - if connector_metadata is None: - connector_metadata = serializer.instance.connector_metadata - serializer.save( - connector_id=connector_id, - connector_metadata=connector_metadata, - modified_by=self.request.user, - ) # type: ignore - - def perform_create(self, serializer: ConnectorInstanceSerializer) -> None: - connector_metadata = None - connector_id = self.request.data.get(CIKey.CONNECTOR_ID) - try: - connector_metadata = self._get_connector_metadata(connector_id=connector_id) - except Exception as exc: - logger.error(f"Error while obtaining ConnectorAuth: {exc}") - raise OAuthTimeOut - serializer.save( - connector_id=connector_id, - connector_metadata=connector_metadata, - created_by=self.request.user, - modified_by=self.request.user, - ) # type: ignore - - def create(self, request: Any) -> Response: - # Overriding default exception behavior - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - try: - self.perform_create(serializer) - except IntegrityError: - raise DuplicateData( - f"{CIKey.CONNECTOR_EXISTS}, \ - {CIKey.DUPLICATE_API}" - ) - headers = self.get_success_headers(serializer.data) - return Response( - serializer.data, status=status.HTTP_201_CREATED, headers=headers - ) diff --git a/backend/connector_auth/__init__.py b/backend/connector_auth/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/connector_auth/admin.py b/backend/connector_auth/admin.py deleted file mode 100644 index 014dfec3e..000000000 --- a/backend/connector_auth/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import ConnectorAuth - -admin.site.register(ConnectorAuth) diff --git a/backend/connector_auth/apps.py b/backend/connector_auth/apps.py deleted file mode 100644 index 6925d9844..000000000 --- a/backend/connector_auth/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class ConnectorAuthConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "connector_auth" diff --git a/backend/connector_auth/constants.py b/backend/connector_auth/constants.py deleted file mode 100644 index 886968d87..000000000 --- a/backend/connector_auth/constants.py +++ /dev/null @@ -1,18 +0,0 @@ -class ConnectorAuthKey: - OAUTH_KEY = "oauth-key" - - -class SocialAuthConstants: - UID = "uid" - PROVIDER = "provider" - ACCESS_TOKEN = "access_token" - REFRESH_TOKEN = "refresh_token" - TOKEN_TYPE = "token_type" - AUTH_TIME = "auth_time" - EXPIRES = "expires" - - REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S" - REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after - - GOOGLE_OAUTH = "google-oauth2" - GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S" diff --git a/backend/connector_auth/exceptions.py b/backend/connector_auth/exceptions.py deleted file mode 100644 index 603bcc8d9..000000000 --- a/backend/connector_auth/exceptions.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class CacheMissException(APIException): - status_code = 404 - default_detail = "Key doesn't exist." - - -class EnrichConnectorMetadataException(APIException): - status_code = 500 - default_detail = "Connector metadata could not be enriched" - - -class MissingParamException(APIException): - status_code = 400 - default_detail = "Bad request, missing parameter." - - def __init__( - self, - code: Optional[str] = None, - param: Optional[str] = None, - ) -> None: - detail = f"Bad request, missing parameter: {param}" - super().__init__(detail, code) - - -class KeyNotConfigured(APIException): - status_code = 500 - default_detail = "Key is not configured correctly" diff --git a/backend/connector_auth/migrations/0001_initial.py b/backend/connector_auth/migrations/0001_initial.py deleted file mode 100644 index 19a702bec..000000000 --- a/backend/connector_auth/migrations/0001_initial.py +++ /dev/null @@ -1,54 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import django.db.models.deletion -import social_django.storage -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name="ConnectorAuth", - fields=[ - ("provider", models.CharField(max_length=32)), - ("uid", models.CharField(db_index=True, max_length=255)), - ("extra_data", models.JSONField(default=dict)), - ("created", models.DateTimeField(auto_now_add=True)), - ("modified", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "user", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="connector_auth", - to=settings.AUTH_USER_MODEL, - ), - ), - ], - bases=(models.Model, social_django.storage.DjangoUserMixin), - ), - migrations.AddConstraint( - model_name="connectorauth", - constraint=models.UniqueConstraint( - fields=("provider", "uid"), name="unique_provider_uid" - ), - ), - ] diff --git a/backend/connector_auth/migrations/__init__.py b/backend/connector_auth/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/connector_auth/models.py b/backend/connector_auth/models.py deleted file mode 100644 index 653400055..000000000 --- a/backend/connector_auth/models.py +++ /dev/null @@ -1,138 +0,0 @@ -import logging -import uuid -from typing import Any - -from account.models import User -from connector_auth.constants import SocialAuthConstants -from connector_auth.pipeline.google import GoogleAuthHelper -from django.db import models -from django.db.models.query import QuerySet -from rest_framework.request import Request -from social_django.fields import JSONField -from social_django.models import AbstractUserSocialAuth, DjangoStorage -from social_django.strategy import DjangoStrategy - -logger = logging.getLogger(__name__) - - -class ConnectorAuthManager(models.Manager): - def get_queryset(self) -> QuerySet: - queryset = super().get_queryset() - # TODO PAN-83: Decrypt here - # for obj in queryset: - # logger.info(f"Decrypting extra_data: {obj.extra_data}") - - return queryset - - -class ConnectorAuth(AbstractUserSocialAuth): - """Social Auth association model, stores tokens. - The relation with `account.User` is only for the library to work - and should be NOT be used to access the secrets. - Use the following static methods instead - ``` - @classmethod - def get_social_auth(cls, provider, id): - - @classmethod - def create_social_auth(cls, user, uid, provider): - ``` - """ - - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey( - User, - related_name="connector_auth", - on_delete=models.SET_NULL, - null=True, - ) - - def __str__(self) -> str: - return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})" - - def save(self, *args: Any, **kwargs: Any) -> Any: - # TODO PAN-83: Encrypt here - # logger.info(f"Encrypting extra_data: {self.extra_data}") - return super().save(*args, **kwargs) - - def set_extra_data(self, extra_data=None): # type: ignore - ConnectorAuth.check_credential_format(extra_data) - if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH: - extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) - return super().set_extra_data(extra_data) - - def refresh_token(self, strategy, *args, **kwargs): # type: ignore - """Override of Python Social Auth (PSA)'s refresh_token functionality - to store uid, provider.""" - token = self.extra_data.get("refresh_token") or self.extra_data.get( - "access_token" - ) - backend = self.get_backend_instance(strategy) - if token and backend and hasattr(backend, "refresh_token"): - response = backend.refresh_token(token, *args, **kwargs) - extra_data = backend.extra_data(self, self.uid, response, self.extra_data) - extra_data[SocialAuthConstants.PROVIDER] = backend.name - extra_data[SocialAuthConstants.UID] = self.uid - if self.set_extra_data(extra_data): # type: ignore - self.save() - - def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]: - """Uses Social Auth's ability to refresh tokens if necessary. - - Returns: - Tuple[JSONField, bool]: JSONField of connector metadata - and flag indicating if tokens were refreshed - """ - # To avoid circular dependency error on import - from social_django.utils import load_strategy - - refreshed_token = False - strategy: DjangoStrategy = load_strategy(request=request) - existing_access_token = self.access_token - new_access_token = self.get_access_token(strategy) - if new_access_token != existing_access_token: - refreshed_token = True - related_connector_instances = self.connectorinstance_set.all() - for connector_instance in related_connector_instances: - connector_instance.connector_metadata = self.extra_data - connector_instance.save() - logger.info( - f"Refreshed access token for connector {connector_instance.id}, " - f"provider: {self.provider}, uid: {self.uid}" - ) - - return self.extra_data, refreshed_token - - @staticmethod - def check_credential_format( - oauth_credentials: dict[str, str], raise_exception: bool = True - ) -> bool: - if ( - SocialAuthConstants.PROVIDER in oauth_credentials - and SocialAuthConstants.UID in oauth_credentials - ): - return True - else: - if raise_exception: - raise ValueError( - "Auth credential should have provider, uid and connector guid" - ) - return False - - objects = ConnectorAuthManager() - - class Meta: - app_label = "connector_auth" - constraints = [ - models.UniqueConstraint( - fields=[ - "provider", - "uid", - ], - name="unique_provider_uid", - ), - ] - - -class ConnectorDjangoStorage(DjangoStorage): - user = ConnectorAuth diff --git a/backend/connector_auth/pipeline/common.py b/backend/connector_auth/pipeline/common.py deleted file mode 100644 index 3d82cbbf5..000000000 --- a/backend/connector_auth/pipeline/common.py +++ /dev/null @@ -1,111 +0,0 @@ -import logging -from typing import Any, Optional - -from account.models import User -from connector_auth.constants import ConnectorAuthKey, SocialAuthConstants -from connector_auth.models import ConnectorAuth -from connector_auth.pipeline.google import GoogleAuthHelper -from django.conf import settings -from django.core.cache import cache -from rest_framework.exceptions import PermissionDenied -from social_core.backends.oauth import BaseOAuth2 - -logger = logging.getLogger(__name__) - - -def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]: - """Checks if user is authenticated (will be handled in auth middleware, - present as a fail safe) - - Args: - user (account.User): User model - - Raises: - PermissionDenied: Unauthorized user - - Returns: - dict: Carrying response details for auth pipeline - """ - if not user: - raise PermissionDenied(backend) - return {**kwargs} - - -def cache_oauth_creds( - backend: BaseOAuth2, - details: dict[str, str], - response: dict[str, str], - uid: str, - user: User, - *args: Any, - **kwargs: Any, -) -> dict[str, str]: - """Used to cache the extra data JSON in redis against a key. - - This contains the access and refresh token along with details - regarding expiry, uid (unique ID given by provider) and provider. - """ - cache_key = kwargs.get("cache_key") or backend.strategy.session_get( - settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], - ConnectorAuthKey.OAUTH_KEY, - ) - extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs) - extra_data[SocialAuthConstants.PROVIDER] = backend.name - extra_data[SocialAuthConstants.UID] = uid - - if backend.name == SocialAuthConstants.GOOGLE_OAUTH: - extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) - - cache.set( - cache_key, - extra_data, - int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND), - ) - return {**kwargs} - - -class ConnectorAuthHelper: - @staticmethod - def get_oauth_creds_from_cache( - cache_key: str, delete_key: bool = True - ) -> Optional[dict[str, str]]: - """Retrieves oauth credentials from the cache. - - Args: - cache_key (str): Key to obtain credentials from - - Returns: - Optional[dict[str,str]]: Returns credentials. None if it doesn't exist - """ - oauth_creds: dict[str, str] = cache.get(cache_key) - if delete_key: - cache.delete(cache_key) - return oauth_creds - - @staticmethod - def get_or_create_connector_auth( - oauth_credentials: dict[str, str], user: User = None # type: ignore - ) -> ConnectorAuth: - """Gets or creates a ConnectorAuth object. - - Args: - user (User): Used while creation, can be removed if not required - oauth_credentials (dict[str,str]): Needs to have provider and uid - - Returns: - ConnectorAuth: Object for the respective provider/uid - """ - ConnectorAuth.check_credential_format(oauth_credentials) - provider = oauth_credentials[SocialAuthConstants.PROVIDER] - uid = oauth_credentials[SocialAuthConstants.UID] - connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth( - provider=provider, uid=uid - ) - if not connector_oauth: - connector_oauth = ConnectorAuth.create_social_auth( - user, uid=uid, provider=provider - ) - - # TODO: Remove User's related manager access to ConnectorAuth - connector_oauth.set_extra_data(oauth_credentials) # type: ignore - return connector_oauth diff --git a/backend/connector_auth/pipeline/google.py b/backend/connector_auth/pipeline/google.py deleted file mode 100644 index 71a56fb33..000000000 --- a/backend/connector_auth/pipeline/google.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime, timedelta - -from connector_auth.constants import SocialAuthConstants as AuthConstants -from connector_auth.exceptions import EnrichConnectorMetadataException -from connector_processor.constants import ConnectorKeys - -from unstract.connectors.filesystems.google_drive.constants import GDriveConstants - - -class GoogleAuthHelper: - @staticmethod - def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]: - token_expiry: datetime = datetime.now() - auth_time = kwargs.get(AuthConstants.AUTH_TIME) - expires = kwargs.get(AuthConstants.EXPIRES) - if auth_time and expires: - reference = datetime.utcfromtimestamp(float(auth_time)) - token_expiry = reference + timedelta(seconds=float(expires)) - else: - raise EnrichConnectorMetadataException - # Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN - kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime( - AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT - ) - - # Used by Unstract - kwargs[ConnectorKeys.PATH] = ( - GDriveConstants.ROOT_PREFIX - ) # Acts as a prefix for all paths - kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime( - AuthConstants.REFRESH_AFTER_FORMAT - ) - return kwargs diff --git a/backend/connector_auth/urls.py b/backend/connector_auth/urls.py deleted file mode 100644 index 55337ad20..000000000 --- a/backend/connector_auth/urls.py +++ /dev/null @@ -1,21 +0,0 @@ -from django.urls import include, path, re_path -from rest_framework.urlpatterns import format_suffix_patterns - -from .views import ConnectorAuthViewSet - -connector_auth_cache = ConnectorAuthViewSet.as_view( - { - "get": "cache_key", - } -) - -urlpatterns = format_suffix_patterns( - [ - path("oauth/", include("social_django.urls", namespace="social")), - re_path( - "^oauth/cache-key/(?P.+)$", - connector_auth_cache, - name="connector-cache", - ), - ] -) diff --git a/backend/connector_auth/views.py b/backend/connector_auth/views.py deleted file mode 100644 index 5e097e75e..000000000 --- a/backend/connector_auth/views.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging -import uuid - -from connector_auth.constants import SocialAuthConstants -from connector_auth.exceptions import KeyNotConfigured -from django.conf import settings -from rest_framework import status, viewsets -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.versioning import URLPathVersioning -from utils.user_session import UserSessionUtils - -logger = logging.getLogger(__name__) - - -class ConnectorAuthViewSet(viewsets.ViewSet): - """Contains methods for Connector related authentication.""" - - versioning_class = URLPathVersioning - - def cache_key( - self: "ConnectorAuthViewSet", request: Request, backend: str - ) -> Response: - if backend == SocialAuthConstants.GOOGLE_OAUTH and ( - settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY is None - or settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET is None - ): - msg = ( - f"Keys not configured for {backend}, add env vars " - f"`GOOGLE_OAUTH2_KEY` and `GOOGLE_OAUTH2_SECRET`." - ) - logger.warn(msg) - raise KeyNotConfigured( - msg - + "\nRefer https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name-." # noqa - ) - - random = str(uuid.uuid4()) - user_id = request.user.user_id - org_id = UserSessionUtils.get_organization_id(request) - cache_key = f"oauth:{org_id}|{user_id}|{backend}|{random}" - logger.info(f"Generated cache key: {cache_key}") - return Response( - status=status.HTTP_200_OK, - data={"cache_key": f"{cache_key}"}, - ) diff --git a/backend/notification/__init__.py b/backend/notification/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/notification/apps.py b/backend/notification/apps.py deleted file mode 100644 index 24a58ddc6..000000000 --- a/backend/notification/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class NotificationConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "notification" diff --git a/backend/notification/constants.py b/backend/notification/constants.py deleted file mode 100644 index c106997d3..000000000 --- a/backend/notification/constants.py +++ /dev/null @@ -1,5 +0,0 @@ -class NotificationUrlConstant: - """Constants for Notification Urls.""" - - PIPELINE_UID = "pipeline_uuid" - API_UID = "api_uuid" diff --git a/backend/notification/enums.py b/backend/notification/enums.py deleted file mode 100644 index 991b08cac..000000000 --- a/backend/notification/enums.py +++ /dev/null @@ -1,38 +0,0 @@ -from enum import Enum - - -class NotificationType(Enum): - WEBHOOK = "WEBHOOK" - # Add other notification types as needed - # Example EMAIL = 'EMAIL' - - def get_valid_platforms(self): - if self == NotificationType.WEBHOOK: - return [PlatformType.SLACK.value, PlatformType.API.value] - return [] - - @classmethod - def choices(cls): - return [(e.value, e.name.replace("_", " ").capitalize()) for e in cls] - - -class AuthorizationType(Enum): - BEARER = "BEARER" - API_KEY = "API_KEY" - CUSTOM_HEADER = "CUSTOM_HEADER" - NONE = "NONE" - - @classmethod - def choices(cls): - return [(e.value, e.name.replace("_", " ").capitalize()) for e in cls] - - -class PlatformType(Enum): - SLACK = "SLACK" - API = "API" - # Add other platforms as needed - # Example TEAMS = 'TEAMS' - - @classmethod - def choices(cls): - return [(e.value, e.name.replace("_", " ").capitalize()) for e in cls] diff --git a/backend/notification/helper.py b/backend/notification/helper.py deleted file mode 100644 index 71bd25992..000000000 --- a/backend/notification/helper.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -from typing import Any - -from notification.enums import NotificationType, PlatformType -from notification.models import Notification -from notification.provider.notification_provider import NotificationProvider -from notification.provider.registry import get_notification_provider - -logger = logging.getLogger(__name__) - - -class NotificationHelper: - @classmethod - def send_notification(cls, notifications: list[Notification], payload: Any) -> None: - """Send notification Sends notifications using the appropriate provider - based on the notification type and platform. - - This method iterates through a list of `Notification` objects, determines the - appropriate notification provider based on the notification's type and - platform, and sends the notification with the provided payload. If an error - occurs due to an invalid notification type or platform, it logs the error. - - Args: - notifications (list[Notification]): A list of `Notification` instances to - be processed and sent. - payload (Any): The data to be sent with the notification. This can be any - format expected by the provider - - Returns: - None - """ - for notification in notifications: - notification_type = NotificationType(notification.notification_type) - platform_type = PlatformType(notification.platform) - try: - notification_provider = get_notification_provider( - notification_type, platform_type - ) - notifier: NotificationProvider = notification_provider( - notification=notification, payload=payload - ) - notifier.send() - logger.info(f"Sending notification to {notification}") - except ValueError as e: - logger.error( - f"Error in notification type {notification_type} and platform " - f"{platform_type} for notification {notification}: {e}" - ) diff --git a/backend/notification/migrations/0001_initial.py b/backend/notification/migrations/0001_initial.py deleted file mode 100644 index b2c27a737..000000000 --- a/backend/notification/migrations/0001_initial.py +++ /dev/null @@ -1,122 +0,0 @@ -# Generated by Django 4.2.1 on 2024-08-07 08:48 - -import uuid - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ("pipeline", "0002_alter_pipeline_last_run_status"), - ("api", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="Notification", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "name", - models.CharField( - db_comment="Name of the notification.", - default="Notification", - max_length=255, - ), - ), - ("url", models.URLField(null=True)), - ( - "authorization_key", - models.CharField(blank=True, max_length=255, null=True), - ), - ( - "authorization_header", - models.CharField(blank=True, max_length=255, null=True), - ), - ( - "authorization_type", - models.CharField( - choices=[ - ("BEARER", "Bearer"), - ("API_KEY", "Api key"), - ("CUSTOM_HEADER", "Custom header"), - ("NONE", "None"), - ], - default="NONE", - max_length=50, - ), - ), - ("max_retries", models.IntegerField(default=0)), - ( - "platform", - models.CharField( - blank=True, - choices=[("SLACK", "Slack"), ("API", "Api")], - max_length=50, - null=True, - ), - ), - ( - "notification_type", - models.CharField( - choices=[("WEBHOOK", "Webhook")], - default="WEBHOOK", - max_length=50, - ), - ), - ( - "is_active", - models.BooleanField( - db_comment="Flag indicating whether the notification is active or not.", - default=True, - ), - ), - ( - "api", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="notifications", - to="api.apideployment", - ), - ), - ( - "pipeline", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="notifications", - to="pipeline.pipeline", - ), - ), - ], - ), - migrations.AddConstraint( - model_name="notification", - constraint=models.UniqueConstraint( - fields=("name", "pipeline"), name="unique_name_pipeline" - ), - ), - migrations.AddConstraint( - model_name="notification", - constraint=models.UniqueConstraint( - fields=("name", "api"), name="unique_name_api" - ), - ), - ] diff --git a/backend/notification/migrations/__init__.py b/backend/notification/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/notification/models.py b/backend/notification/models.py deleted file mode 100644 index 121ac0599..000000000 --- a/backend/notification/models.py +++ /dev/null @@ -1,91 +0,0 @@ -import uuid - -from api.models import APIDeployment -from django.db import models -from pipeline.models import Pipeline -from utils.models.base_model import BaseModel - -from .enums import AuthorizationType, NotificationType, PlatformType - -NOTIFICATION_NAME_MAX_LENGTH = 255 - - -class Notification(BaseModel): - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - name = models.CharField( - max_length=NOTIFICATION_NAME_MAX_LENGTH, - db_comment="Name of the notification.", - default="Notification", - ) - url = models.URLField(null=True) # URL for webhook or other endpoints - authorization_key = models.CharField( - max_length=255, blank=True, null=True - ) # Authorization Key or API Key - authorization_header = models.CharField( - max_length=255, blank=True, null=True - ) # Header Name for custom headers - authorization_type = models.CharField( - max_length=50, - choices=AuthorizationType.choices(), - default=AuthorizationType.NONE.value, - ) - max_retries = models.IntegerField( - default=0 - ) # Maximum number of times to retry webhook - platform = models.CharField( - max_length=50, - choices=PlatformType.choices(), - blank=True, - null=True, - ) - notification_type = models.CharField( - max_length=50, - choices=NotificationType.choices(), - default=NotificationType.WEBHOOK.value, - ) - is_active = models.BooleanField( - default=True, - db_comment="Flag indicating whether the notification is active or not.", - ) - # Foreign keys to specific models - pipeline = models.ForeignKey( - Pipeline, - on_delete=models.CASCADE, - related_name="notifications", - null=True, - blank=True, - ) - api = models.ForeignKey( - APIDeployment, - on_delete=models.CASCADE, - related_name="notifications", - null=True, - blank=True, - ) - - class Meta: - constraints = [ - models.UniqueConstraint( - fields=["name", "pipeline"], name="unique_name_pipeline" - ), - models.UniqueConstraint(fields=["name", "api"], name="unique_name_api"), - ] - - def save(self, *args, **kwargs): - # Validation for platforms - valid_platforms = NotificationType(self.notification_type).get_valid_platforms() - if self.platform and self.platform not in valid_platforms: - raise ValueError( - f"Invalid platform '{self.platform}' for notification type " - f"'{self.notification_type}'. " - f"Valid options are: {', '.join(valid_platforms)}." - ) - - # Allow saving only if the platform is valid or not required - super().save(*args, **kwargs) - - def __str__(self): - return ( - f"Notification {self.id}: (Type: {self.notification_type}, " - f"Platform: {self.platform}, Url: {self.url}))" - ) diff --git a/backend/notification/provider/__init__.py b/backend/notification/provider/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/notification/provider/notification_provider.py b/backend/notification/provider/notification_provider.py deleted file mode 100644 index dbfeba284..000000000 --- a/backend/notification/provider/notification_provider.py +++ /dev/null @@ -1,28 +0,0 @@ -from abc import ABC, abstractmethod - -from django.conf import settings -from notification.models import Notification - - -class NotificationProvider(ABC): - NOTIFICATION_TIMEOUT = settings.NOTIFICATION_TIMEOUT - RETRY_DELAY = 10 # Seconds - - def __init__(self, notification: Notification, payload): - self.payload = payload - self.notification = notification - - @abstractmethod - def send(self): - """Method to be overridden in child classes for sending the - notification.""" - raise NotImplementedError("Subclasses should implement this method.") - - def validate(self): - """Method to validate the notification data.""" - pass - - @abstractmethod - def get_headers(self): - """Method to get the headers for the notification.""" - raise NotImplementedError("Subclasses should implement this method.") diff --git a/backend/notification/provider/registry.py b/backend/notification/provider/registry.py deleted file mode 100644 index f746c32a3..000000000 --- a/backend/notification/provider/registry.py +++ /dev/null @@ -1,54 +0,0 @@ -from notification.enums import NotificationType, PlatformType -from notification.provider.notification_provider import NotificationProvider -from notification.provider.webhook.api_webhook import APIWebhook -from notification.provider.webhook.slack_webhook import SlackWebhook - -REGISTRY = { - NotificationType.WEBHOOK: { - PlatformType.SLACK: SlackWebhook, - PlatformType.API: APIWebhook, - # Add other platform-specific classes here - }, - # Add other notification types and classes here -} - - -def get_notification_provider( - notification_type: NotificationType, platform_type: PlatformType -) -> NotificationProvider: - """Get Notification provider based on notification type and platform type - It uses the REGISTRY to map the combination of notification type and - platform type to the corresponding NotificationProvider class. - - If the provided combination is not found in the REGISTRY, a ValueError is raised. - - Note: - This function assumes that the REGISTRY dictionary is correctly populated - with the appropriate NotificationProvider classes for each combination of - notification type and platform type. - - See Also: - - NotificationType - - PlatformType - - NotificationProvider - - REGISTRY - - Parameters: - notification_type (NotificationType): The type of notification. - platform_type (PlatformType): The platform/provider type for the notification. - - Returns: - NotificationProvider: The appropriate NotificationProvider class for - the given combination. - - Raises: - ValueError: If the provided combination is not found in the REGISTRY. - """ - if notification_type not in REGISTRY: - raise ValueError(f"Unsupported notification type: {notification_type}") - - platform_registry = REGISTRY[notification_type] - if platform_type not in platform_registry: - raise ValueError(f"Unsupported platform type: {platform_type}") - - return platform_registry[platform_type] diff --git a/backend/notification/provider/webhook/__init__.py b/backend/notification/provider/webhook/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/notification/provider/webhook/api_webhook.py b/backend/notification/provider/webhook/api_webhook.py deleted file mode 100644 index dc53231d9..000000000 --- a/backend/notification/provider/webhook/api_webhook.py +++ /dev/null @@ -1,13 +0,0 @@ -from notification.provider.webhook.webhook import Webhook - - -class APIWebhook(Webhook): - def send(self): - """Send the API webhook notification.""" - super().send() - - def get_headers(self): - """API-specific headers.""" - headers = super().get_headers() - headers["Content-Type"] = "application/json" - return headers diff --git a/backend/notification/provider/webhook/slack_webhook.py b/backend/notification/provider/webhook/slack_webhook.py deleted file mode 100644 index b529d510a..000000000 --- a/backend/notification/provider/webhook/slack_webhook.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging - -from notification.provider.webhook.webhook import Webhook - -logger = logging.getLogger(__name__) - - -class SlackWebhook(Webhook): - def send(self): - """Send the Slack webhook notification.""" - formatted_payload = self.format_payload() - self.payload = formatted_payload - super().send() - - def get_headers(self): - """Slack-specific headers.""" - headers = super().get_headers() - headers["Content-Type"] = "application/json" - return headers - - def format_payload(self) -> dict: - """Format the payload to match Slack's expected structure.""" - if "text" not in self.payload: - # Construct a basic Slack message with 'text' field - formatted_payload = { - "text": "Notification", - "blocks": self.create_blocks_from_payload(), - } - else: - # If 'text' is already present, format accordingly - formatted_payload = { - "text": self.payload.pop("text"), - "blocks": self.create_blocks_from_payload(), - } - return formatted_payload - - def create_blocks_from_payload(self) -> list: - """Create Slack blocks from the given payload.""" - blocks = [] - # Header - blocks.append( - { - "type": "section", - "text": {"type": "mrkdwn", "text": "*Unstract Update:*"}, - } - ) - # Add a divider for separation - blocks.append({"type": "divider"}) - # Add each key-value pair to the blocks - for key, value in self.payload.items(): - formatted_key = key.replace("_", " ").title() - blocks.append( - { - "type": "section", - "text": {"type": "mrkdwn", "text": f"*{formatted_key}:* {value}"}, - } - ) - # Footer - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": "*---*"}}) - return blocks diff --git a/backend/notification/provider/webhook/webhook.py b/backend/notification/provider/webhook/webhook.py deleted file mode 100644 index e90f06238..000000000 --- a/backend/notification/provider/webhook/webhook.py +++ /dev/null @@ -1,156 +0,0 @@ -# notifications.py - -import logging -from typing import Any, Optional - -import requests -from celery import shared_task -from notification.enums import AuthorizationType -from notification.provider.notification_provider import NotificationProvider - -logger = logging.getLogger(__name__) - - -class WebhookNotificationArg: - MAX_RETRIES = "max_retries" - RETRY_DELAY = "retry_delay" - - -class HeaderConstants: - APPLICATION_JSON = "application/json" - - -class Webhook(NotificationProvider): - def send(self): - """Send the webhook notification.""" - try: - headers = self.get_headers() - self.validate() - except ValueError as e: - logger.error(f"Error validating notification {self.notification} :: {e}") - return - send_webhook_notification.apply_async( - (self.notification.url, self.payload, headers, self.NOTIFICATION_TIMEOUT), - kwargs={ - WebhookNotificationArg.MAX_RETRIES: self.notification.max_retries, - WebhookNotificationArg.RETRY_DELAY: self.RETRY_DELAY, - }, - ) - - def validate(self): - """Validate notification. - - Returns: - _type_: None - """ - if not self.notification.url: - raise ValueError("Webhook URL is required.") - if not self.payload: - raise ValueError("Payload is required.") - return super().validate() - - def get_headers(self): - """ - Get the headers for the notification based on the authorization type and key. - Raises: - ValueError: _description_ - - Returns: - dict[str, str]: A dictionary containing the headers. - """ - headers = {} - try: - authorization_type = AuthorizationType( - self.notification.authorization_type.upper() - ) - except ValueError: - raise ValueError( - "Unsupported authorization type: " - f"{self.notification.authorization_type}" - ) - authorization_key = self.notification.authorization_key - authorization_header = self.notification.authorization_header - - header_formats = { - AuthorizationType.BEARER: lambda key: { - "Authorization": f"Bearer {key}", - "Content-Type": HeaderConstants.APPLICATION_JSON, - }, - AuthorizationType.API_KEY: lambda key: { - "Authorization": key, - "Content-Type": HeaderConstants.APPLICATION_JSON, - }, - AuthorizationType.CUSTOM_HEADER: lambda key: { - authorization_header: key, - "Content-Type": HeaderConstants.APPLICATION_JSON, - }, - AuthorizationType.NONE: lambda _: { - "Content-Type": HeaderConstants.APPLICATION_JSON, - }, - } - - if authorization_type not in header_formats: - raise ValueError(f"Unsupported authorization type: {authorization_type}") - - headers = header_formats[authorization_type](authorization_key) - - # Check if custom header type has required details - if authorization_type == AuthorizationType.CUSTOM_HEADER: - if not authorization_header or not authorization_key: - raise ValueError( - "Custom header or key missing for custom authorization." - ) - return headers - - -@shared_task(bind=True, name="send_webhook_notification") -def send_webhook_notification( - self, - url: str, - payload: Any, - headers: Any = None, - timeout: int = 10, - max_retries: Optional[int] = None, - retry_delay: int = 10, -): - """Celery task to send a webhook with retries and error handling. - - Args: - url (str): The URL to which the webhook should be sent. - payload (dict): The payload to be sent in the webhook request. - headers (dict, optional): Optional headers to include in the request. - Defaults to None. - timeout (int, optional): The request timeout in seconds. Defaults to 10. - max_retries (int, optional): The maximum number of retries allowed. - Defaults to None. - retry_delay (int, optional): The delay between retries in seconds. - Defaults to 10. - - Returns: - None - """ - try: - response = requests.post(url, json=payload, headers=headers, timeout=timeout) - response.raise_for_status() - if not (200 <= response.status_code < 300): - logger.error( - f"Request to {url} failed with status code {response.status_code}. " - f"Response: {response.text}" - ) - except requests.exceptions.RequestException as exc: - if max_retries is not None: - if self.request.retries < max_retries: - logger.warning( - f"Request to {url} failed. Retrying in {retry_delay} seconds. " - f"Attempt {self.request.retries + 1}/{max_retries}. Error: {exc}" - ) - raise self.retry(exc=exc, countdown=retry_delay) - else: - logger.error( - f"Failed to send webhook to {url} after {max_retries} attempts. " - f"Error: {exc}" - ) - return None - else: - logger.error(f"Webhook request to {url} failed with error: {exc}") - return None diff --git a/backend/notification/serializers.py b/backend/notification/serializers.py deleted file mode 100644 index a5e405a52..000000000 --- a/backend/notification/serializers.py +++ /dev/null @@ -1,130 +0,0 @@ -from rest_framework import serializers - -from .enums import AuthorizationType, NotificationType, PlatformType -from .models import Notification - - -class NotificationSerializer(serializers.ModelSerializer): - notification_type = serializers.ChoiceField(choices=NotificationType.choices()) - authorization_type = serializers.ChoiceField(choices=AuthorizationType.choices()) - platform = serializers.ChoiceField(choices=PlatformType.choices(), required=False) - max_retries = serializers.IntegerField( - max_value=4, min_value=0, default=0, required=False - ) - - class Meta: - model = Notification - fields = "__all__" - - def validate(self, data): - """Validate the data for the NotificationSerializer.""" - # General validation for the relationship between api and pipeline - self._validate_api_or_pipeline(data) - self._validate_authorization(data) - return data - - def _validate_api_or_pipeline(self, data): - """Ensure either 'api' or 'pipeline' is provided, but not both.""" - api = data.get("api", getattr(self.instance, "api", None)) - pipeline = data.get("pipeline", getattr(self.instance, "pipeline", None)) - if api and pipeline: - raise serializers.ValidationError( - "Only one of 'api' or 'pipeline' can be provided." - ) - - if not api and not pipeline: - raise serializers.ValidationError( - "Either 'api' or 'pipeline' must be provided." - ) - - def _validate_authorization(self, data): - """Ensure required authorization fields are provided based on the - authorization type. - - Getting existing data in the case of PATCH request - """ - authorization_type = data.get( - "authorization_type", getattr(self.instance, "authorization_type", None) - ) - authorization_key = data.get( - "authorization_key", getattr(self.instance, "authorization_key", None) - ) - authorization_header = data.get( - "authorization_header", getattr(self.instance, "authorization_header", None) - ) - - try: - authorization_type_enum = AuthorizationType(authorization_type) - except ValueError: - raise serializers.ValidationError( - f"Invalid authorization type '{authorization_type}'." - ) - - if authorization_type_enum in [ - AuthorizationType.BEARER, - AuthorizationType.API_KEY, - AuthorizationType.CUSTOM_HEADER, - ]: - if not authorization_key: - raise serializers.ValidationError( - { - "authorization_key": ( - "Authorization key is required for authorization " - f"type '{authorization_type_enum.value}'." - ) - } - ) - - if ( - authorization_type_enum == AuthorizationType.CUSTOM_HEADER - and not authorization_header - ): - raise serializers.ValidationError( - { - "authorization_header": ( - "Authorization header is required when using " - "CUSTOM_HEADER authorization type." - ) - } - ) - - def validate_platform(self, value): - """Validate the platform field based on the notification_type.""" - notification_type = self.initial_data.get( - "notification_type", getattr(self.instance, "notification_type", None) - ) - if not notification_type: - raise serializers.ValidationError("Notification type must be provided.") - - valid_platforms = NotificationType(notification_type).get_valid_platforms() - if value and value not in valid_platforms: - raise serializers.ValidationError( - f"Invalid platform '{value}' for notification type " - f"'{notification_type}'. " - f"Valid options are: {', '.join(valid_platforms)}." - ) - return value - - def validate_name(self, value): - """Check uniqueness of the name with respect to either 'api' or - 'pipeline'.""" - api = self.initial_data.get("api", getattr(self.instance, "api", None)) - pipeline = self.initial_data.get( - "pipeline", getattr(self.instance, "pipeline", None) - ) - - queryset = Notification.objects.filter(name=value) - if self.instance: - queryset = queryset.exclude(id=self.instance.id) - - if api and queryset.filter(api=api).exists(): - raise serializers.ValidationError( - "A notification with this name and API already exists.", - code="unique_api", - ) - elif pipeline and queryset.filter(pipeline=pipeline).exists(): - raise serializers.ValidationError( - "A notification with this name and pipeline already exists.", - code="unique_pipeline", - ) - return value diff --git a/backend/notification/urls.py b/backend/notification/urls.py deleted file mode 100644 index 2e356b400..000000000 --- a/backend/notification/urls.py +++ /dev/null @@ -1,27 +0,0 @@ -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns - -from .views import NotificationViewSet - -notification_list = NotificationViewSet.as_view({"get": "list", "post": "create"}) -notification_detail = NotificationViewSet.as_view( - { - "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", - } -) - -urlpatterns = format_suffix_patterns( - [ - path("", notification_list, name="notification-list"), - path("/", notification_detail, name="notification-detail"), - path( - "pipeline//", - notification_list, - name="pipeline-notification-list", - ), - path("api//", notification_list, name="api-notification-list"), - ] -) diff --git a/backend/notification/views.py b/backend/notification/views.py deleted file mode 100644 index beb5e75a2..000000000 --- a/backend/notification/views.py +++ /dev/null @@ -1,36 +0,0 @@ -from api.deployment_helper import DeploymentHelper -from api.exceptions import APINotFound -from notification.constants import NotificationUrlConstant -from pipeline.exceptions import PipelineNotFound -from pipeline.models import Pipeline -from pipeline.pipeline_processor import PipelineProcessor -from rest_framework import viewsets - -from .models import Notification -from .serializers import NotificationSerializer - - -class NotificationViewSet(viewsets.ModelViewSet): - serializer_class = NotificationSerializer - - def get_queryset(self): - queryset = Notification.objects.all() - pipeline_uuid = self.kwargs.get(NotificationUrlConstant.PIPELINE_UID) - api_uuid = self.kwargs.get(NotificationUrlConstant.API_UID) - - if pipeline_uuid: - try: - pipeline = PipelineProcessor.fetch_pipeline( - pipeline_id=pipeline_uuid, check_active=False - ) - queryset = queryset.filter(pipeline=pipeline) - except Pipeline.DoesNotExist: - raise PipelineNotFound() - - elif api_uuid: - api = DeploymentHelper.get_api_by_id(api_id=api_uuid) - if not api: - raise APINotFound() - queryset = queryset.filter(api=api) - - return queryset diff --git a/backend/pipeline/__init__.py b/backend/pipeline/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/pipeline/constants.py b/backend/pipeline/constants.py deleted file mode 100644 index cbb78c79a..000000000 --- a/backend/pipeline/constants.py +++ /dev/null @@ -1,66 +0,0 @@ -class PipelineConstants: - """Constants for Pipelines.""" - - TYPE = "type" - ETL_PIPELINE = "ETL" - TASK_PIPELINE = "TASK" - ETL = "etl" - TASK = "task" - CREATE_ACTION = "create" - UPDATE_ACTION = "update" - PIPELINE_GUID = "id" - ACTION = "action" - NOT_CONFIGURED = "Connector not configured." - SOURCE_NOT_CONFIGURED = "Source not configured." - DESTINATION_NOT_CONFIGURED = "Destination not configured." - SOURCE_ICON = "source_icon" - DESTINATION_ICON = "destination_icon" - SOURCE_NAME = "source_name" - DESTINATION_NAME = "destination_name" - INPUT_FILE = "input_file_connector" - INPUT_DB = "input_db_connector" - OUTPUT_FILE = "output_file_connector" - OUTPUT_DB = "output_db_connector" - SOURCE = "source" - DEST = "dest" - - -class PipelineExecutionKey: - PIPELINE = "pipeline" - EXECUTION = "execution" - - -class PipelineKey: - """Constants for the Pipeline model.""" - - PIPELINE_GUID = "id" - PIPELINE_NAME = "pipeline_name" - WORKFLOW = "workflow" - APP_ID = "app_id" - ACTIVE = "active" - SCHEDULED = "scheduled" - PIPELINE_TYPE = "pipeline_type" - RUN_COUNT = "run_count" - LAST_RUN_TIME = "last_run_time" - LAST_RUN_STATUS = "last_run_status" - # Used by serializer - CRON_DATA = "cron_data" - WORKFLOW_NAME = "workflow_name" - WORKFLOW_ID = "workflow_id" - CRON_STRING = "cron_string" - PIPELINE_ID = "pipeline_id" - - -class PipelineErrors: - PIPELINE_EXISTS = "Pipeline with this configuration might already exist or some mandatory field is missing." # noqa: E501 - DUPLICATE_API = "It appears that a duplicate call may have been made." - INVALID_WF = "The provided workflow does not exist" - - -class PipelineURL: - """Constants for URL names.""" - - DETAIL = "pipeline-detail" - EXECUTIONS = "pipeline-executions" - LIST = "pipeline-list" - EXECUTE = "pipeline-execute" diff --git a/backend/pipeline/deployment_helper.py b/backend/pipeline/deployment_helper.py deleted file mode 100644 index 1044220c1..000000000 --- a/backend/pipeline/deployment_helper.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging -from typing import Any - -from api.api_key_validator import BaseAPIKeyValidator -from api.exceptions import InvalidAPIRequest -from api.key_helper import KeyHelper -from pipeline.exceptions import PipelineNotFound -from pipeline.pipeline_processor import PipelineProcessor -from rest_framework.request import Request - -logger = logging.getLogger(__name__) - - -class DeploymentHelper(BaseAPIKeyValidator): - @staticmethod - def validate_parameters(request: Request, **kwargs: Any) -> None: - """Validate pipeline_id for pipeline deployments.""" - pipeline_id = kwargs.get("pipeline_id") or request.data.get("pipeline_id") - if not pipeline_id: - raise InvalidAPIRequest("Missing params pipeline_id") - - @staticmethod - def validate_and_process( - self: Any, request: Request, func: Any, api_key: str, *args: Any, **kwargs: Any - ) -> Any: - """Fetch pipeline and validate API key.""" - pipeline_id = kwargs.get("pipeline_id") or request.data.get("pipeline_id") - pipeline = PipelineProcessor.get_active_pipeline(pipeline_id=pipeline_id) - if not pipeline: - raise PipelineNotFound() - KeyHelper.validate_api_key(api_key=api_key, instance=pipeline) - kwargs["pipeline"] = pipeline - return func(self, request, *args, **kwargs) diff --git a/backend/pipeline/dto.py b/backend/pipeline/dto.py deleted file mode 100644 index 3ed83bf30..000000000 --- a/backend/pipeline/dto.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Optional - - -class PipelineStatusPayload: - def __init__( - self, - type: str, - pipeline_id: str, - pipeline_name: str, - status: str, - execution_id: Optional[str] = None, - error_message: Optional[str] = None, - ): - self.type = type - self.pipeline_id = pipeline_id - self.pipeline_name = pipeline_name - self.status = status - self.execution_id = execution_id - self.error_message = error_message - - def to_dict(self) -> dict[str, Any]: - """Convert the payload DTO to a dictionary.""" - payload = { - "type": self.type, - "pipeline_id": str(self.pipeline_id), - "pipeline_name": self.pipeline_name, - "status": self.status, - } - if self.execution_id: - payload["execution_id"] = str(self.execution_id) - if self.error_message: - payload["error_message"] = self.error_message - return payload diff --git a/backend/pipeline/exceptions.py b/backend/pipeline/exceptions.py deleted file mode 100644 index 7d2885a98..000000000 --- a/backend/pipeline/exceptions.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class NotFoundException(APIException): - status_code = 404 - default_detail = "The requested resource was not found." - - -class WorkflowTriggerError(APIException): - status_code = 400 - default_detail = "Error triggering workflow. Pipeline created" - - -class PipelineExecuteError(APIException): - status_code = 500 - default_detail = "Error executing pipline" - - -class InactivePipelineError(APIException): - status_code = 422 - default_detail = "Pipeline is inactive, please activate the pipeline" - - def __init__( - self, - pipeline_name: Optional[str] = None, - detail: Optional[str] = None, - code: Optional[str] = None, - ): - if pipeline_name: - self.default_detail = ( - f"Pipeline '{pipeline_name}' is inactive, " - "please activate the pipeline" - ) - super().__init__(detail, code) - - -class MandatoryPipelineType(APIException): - status_code = 400 - default_detail = "Pipeline type is mandatory" - - -class MandatoryWorkflowId(APIException): - status_code = 400 - default_detail = "Workflow ID is mandatory" - - -class MandatoryCronSchedule(APIException): - status_code = 400 - default_detail = "Cron schedule is mandatory" - - -class PipelineNotFound(NotFoundException): - default_detail = "Pipeline not found" diff --git a/backend/pipeline/execution_view.py b/backend/pipeline/execution_view.py deleted file mode 100644 index 45f83f438..000000000 --- a/backend/pipeline/execution_view.py +++ /dev/null @@ -1,35 +0,0 @@ -from permissions.permission import IsOwner -from pipeline.serializers.execute import DateRangeSerializer -from rest_framework import viewsets -from rest_framework.versioning import URLPathVersioning -from utils.pagination import CustomPagination -from workflow_manager.workflow.models.execution import WorkflowExecution -from workflow_manager.workflow.serializers import WorkflowExecutionSerializer - - -class PipelineExecutionViewSet(viewsets.ModelViewSet): - versioning_class = URLPathVersioning - permission_classes = [IsOwner] - serializer_class = WorkflowExecutionSerializer - pagination_class = CustomPagination - - CREATED_AT_FIELD_DESC = "-created_at" - START_DATE_FIELD = "start_date" - END_DATE_FIELD = "end_date" - - def get_queryset(self): - # Get the pipeline_id from the URL path - pipeline_id = self.kwargs.get("pk") - queryset = WorkflowExecution.objects.filter(pipeline_id=pipeline_id) - - # Validate start_date and end_date parameters using DateRangeSerializer - date_range_serializer = DateRangeSerializer(data=self.request.query_params) - date_range_serializer.is_valid(raise_exception=True) - start_date = date_range_serializer.validated_data.get(self.START_DATE_FIELD) - end_date = date_range_serializer.validated_data.get(self.END_DATE_FIELD) - - if start_date and end_date: - queryset = queryset.filter(created_at__range=(start_date, end_date)) - - queryset = queryset.order_by(self.CREATED_AT_FIELD_DESC) - return queryset diff --git a/backend/pipeline/manager.py b/backend/pipeline/manager.py deleted file mode 100644 index f6a339210..000000000 --- a/backend/pipeline/manager.py +++ /dev/null @@ -1,59 +0,0 @@ -import logging -from typing import Any, Optional - -from django.conf import settings -from django.urls import reverse -from pipeline.constants import PipelineKey, PipelineURL -from pipeline.models import Pipeline -from pipeline.pipeline_processor import PipelineProcessor -from rest_framework.request import Request -from rest_framework.response import Response -from utils.request.constants import RequestConstants -from workflow_manager.workflow.constants import WorkflowExecutionKey, WorkflowKey -from workflow_manager.workflow.views import WorkflowViewSet - -from backend.constants import RequestHeader - -logger = logging.getLogger(__name__) - - -class PipelineManager: - """Helps manage the execution and scheduling of pipelines.""" - - @staticmethod - def execute_pipeline( - request: Request, - pipeline_id: str, - execution_id: Optional[str] = None, - ) -> Response: - """Used to execute a pipeline. - - Args: - pipeline_id (str): UUID of the pipeline to execute - execution_id (Optional[str], optional): - Uniquely identifies an execution. Defaults to None. - """ - logger.info(f"Executing pipeline {pipeline_id}, execution: {execution_id}") - pipeline: Pipeline = PipelineProcessor.initialize_pipeline_sync(pipeline_id) - # TODO: Use DRF's request and as_view() instead - request.data[WorkflowKey.WF_ID] = pipeline.workflow.id - if execution_id is not None: - request.data[WorkflowExecutionKey.EXECUTION_ID] = execution_id - wf_viewset = WorkflowViewSet() - return wf_viewset.execute(request=request, pipeline_guid=str(pipeline.pk)) - - @staticmethod - def get_pipeline_execution_data_for_scheduled_run( - pipeline_id: str, - ) -> Optional[dict[str, Any]]: - """Gets the required data to be passed while executing a pipeline Any - changes to pipeline execution needs to be propagated here.""" - callback_url = settings.DJANGO_APP_BACKEND_URL + reverse(PipelineURL.EXECUTE) - job_headers = {RequestHeader.X_API_KEY: settings.INTERNAL_SERVICE_API_KEY} - job_kwargs = { - RequestConstants.VERB: "POST", - RequestConstants.URL: callback_url, - RequestConstants.HEADERS: job_headers, - RequestConstants.DATA: {PipelineKey.PIPELINE_ID: pipeline_id}, - } - return job_kwargs diff --git a/backend/pipeline/migrations/0001_initial.py b/backend/pipeline/migrations/0001_initial.py deleted file mode 100644 index 67c4d4ca7..000000000 --- a/backend/pipeline/migrations/0001_initial.py +++ /dev/null @@ -1,130 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("workflow", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="Pipeline", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "pipeline_name", - models.CharField(default="", max_length=32, unique=True), - ), - ( - "app_id", - models.TextField(blank=True, max_length=32, null=True), - ), - ("active", models.BooleanField(default=False)), - ("scheduled", models.BooleanField(default=False)), - ( - "cron_string", - models.TextField(db_comment="UNIX cron string", max_length=256), - ), - ( - "pipeline_type", - models.CharField( - choices=[ - ("ETL", "ETL"), - ("TASK", "TASK"), - ("DEFAULT", "Default"), - ("APP", "App"), - ], - default="DEFAULT", - ), - ), - ("run_count", models.IntegerField(default=0)), - ("last_run_time", models.DateTimeField(blank=True, null=True)), - ( - "last_run_status", - models.CharField( - choices=[ - ("SUCCESS", "Success"), - ("FAILURE", "Failure"), - ("INPROGRESS", "Inprogress"), - ("YET_TO_START", "Yet to start"), - ("RESTARTING", "Restarting"), - ], - default="YET_TO_START", - ), - ), - ( - "app_icon", - models.URLField( - blank=True, - db_comment="Field to store icon url for Apps", - null=True, - ), - ), - ( - "app_url", - models.URLField( - blank=True, - db_comment="Stores deployed URL for App", - null=True, - ), - ), - ( - "access_control_bundle_id", - models.TextField(blank=True, null=True), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_pipeline", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_pipeline", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "workflow", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="pipeline_workflows", - to="workflow.workflow", - ), - ), - ], - ), - migrations.AddConstraint( - model_name="pipeline", - constraint=models.UniqueConstraint( - fields=("id", "pipeline_type"), name="unique_pipeline" - ), - ), - ] diff --git a/backend/pipeline/migrations/0002_alter_pipeline_last_run_status.py b/backend/pipeline/migrations/0002_alter_pipeline_last_run_status.py deleted file mode 100644 index 9274e6b63..000000000 --- a/backend/pipeline/migrations/0002_alter_pipeline_last_run_status.py +++ /dev/null @@ -1,27 +0,0 @@ -# Generated by Django 4.2.1 on 2024-03-01 06:25 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("pipeline", "0001_initial"), - ] - - operations = [ - migrations.AlterField( - model_name="pipeline", - name="last_run_status", - field=models.CharField( - choices=[ - ("SUCCESS", "Success"), - ("FAILURE", "Failure"), - ("INPROGRESS", "Inprogress"), - ("YET_TO_START", "Yet to start"), - ("RESTARTING", "Restarting"), - ("PAUSED", "Paused"), - ], - default="YET_TO_START", - ), - ), - ] diff --git a/backend/pipeline/migrations/0003_alter_pipeline_active_alter_pipeline_cron_string_and_more.py b/backend/pipeline/migrations/0003_alter_pipeline_active_alter_pipeline_cron_string_and_more.py deleted file mode 100644 index 86a825f0a..000000000 --- a/backend/pipeline/migrations/0003_alter_pipeline_active_alter_pipeline_cron_string_and_more.py +++ /dev/null @@ -1,34 +0,0 @@ -# Generated by Django 4.2.1 on 2024-07-31 14:45 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("pipeline", "0002_alter_pipeline_last_run_status"), - ] - - operations = [ - migrations.AlterField( - model_name="pipeline", - name="active", - field=models.BooleanField( - db_comment="Indicates whether the pipeline is active", default=False - ), - ), - migrations.AlterField( - model_name="pipeline", - name="cron_string", - field=models.TextField( - db_comment="UNIX cron string", max_length=256, null=True - ), - ), - migrations.AlterField( - model_name="pipeline", - name="scheduled", - field=models.BooleanField( - db_comment="Indicates whether the pipeline is scheduled", default=False - ), - ), - ] diff --git a/backend/pipeline/migrations/0004_alter_pipeline_cron_string.py b/backend/pipeline/migrations/0004_alter_pipeline_cron_string.py deleted file mode 100644 index 8812b9a48..000000000 --- a/backend/pipeline/migrations/0004_alter_pipeline_cron_string.py +++ /dev/null @@ -1,20 +0,0 @@ -# Generated by Django 4.2.1 on 2024-08-24 10:45 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("pipeline", "0003_alter_pipeline_active_alter_pipeline_cron_string_and_more"), - ] - - operations = [ - migrations.AlterField( - model_name="pipeline", - name="cron_string", - field=models.TextField( - blank=True, db_comment="UNIX cron string", max_length=256, null=True - ), - ), - ] diff --git a/backend/pipeline/migrations/__init__.py b/backend/pipeline/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/pipeline/models.py b/backend/pipeline/models.py deleted file mode 100644 index bf0dc5494..000000000 --- a/backend/pipeline/models.py +++ /dev/null @@ -1,117 +0,0 @@ -import uuid - -from account.models import User -from django.conf import settings -from django.db import connection, models -from utils.models.base_model import BaseModel -from workflow_manager.workflow.models.workflow import Workflow - -from backend.constants import FieldLengthConstants as FieldLength - -APP_ID_LENGTH = 32 -PIPELINE_NAME_LENGTH = 32 - - -class Pipeline(BaseModel): - """Model to hold data related to Pipelines.""" - - class PipelineType(models.TextChoices): - ETL = "ETL", "ETL" - TASK = "TASK", "TASK" - DEFAULT = "DEFAULT", "Default" - APP = "APP", "App" - - class PipelineStatus(models.TextChoices): - SUCCESS = "SUCCESS", "Success" - FAILURE = "FAILURE", "Failure" - INPROGRESS = "INPROGRESS", "Inprogress" - YET_TO_START = "YET_TO_START", "Yet to start" - RESTARTING = "RESTARTING", "Restarting" - PAUSED = "PAUSED", "Paused" - - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - pipeline_name = models.CharField( - max_length=PIPELINE_NAME_LENGTH, default="", unique=True - ) - workflow = models.ForeignKey( - Workflow, - on_delete=models.CASCADE, - related_name="pipeline_workflows", - null=False, - blank=False, - ) - # Added as text field until a model for App is included. - app_id = models.TextField(null=True, blank=True, max_length=APP_ID_LENGTH) - active = models.BooleanField( - default=False, db_comment="Indicates whether the pipeline is active" - ) - scheduled = models.BooleanField( - default=False, db_comment="Indicates whether the pipeline is scheduled" - ) - cron_string = models.TextField( - db_comment="UNIX cron string", - null=True, - blank=True, - max_length=FieldLength.CRON_LENGTH, - ) - pipeline_type = models.CharField( - choices=PipelineType.choices, default=PipelineType.DEFAULT - ) - run_count = models.IntegerField(default=0) - last_run_time = models.DateTimeField(null=True, blank=True) - last_run_status = models.CharField( - choices=PipelineStatus.choices, default=PipelineStatus.YET_TO_START - ) - app_icon = models.URLField( - null=True, blank=True, db_comment="Field to store icon url for Apps" - ) - app_url = models.URLField( - null=True, blank=True, db_comment="Stores deployed URL for App" - ) - # TODO: Change this to a Forgein key once the bundle is created. - access_control_bundle_id = models.TextField(null=True, blank=True) - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="created_pipeline", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="modified_pipeline", - null=True, - blank=True, - ) - - @property - def api_key_data(self): - return {"pipeline": self.id, "description": f"API Key for {self.pipeline_name}"} - - @property - def api_endpoint(self): - org_schema = connection.tenant.schema_name - deployment_endpoint = settings.API_DEPLOYMENT_PATH_PREFIX + "/pipeline/api" - api_endpoint = f"{deployment_endpoint}/{org_schema}/{self.id}/" - return api_endpoint - - def __str__(self) -> str: - return ( - f"Pipeline({self.id}) (" - f"name: {self.pipeline_name}, " - f"cron string: {self.cron_string}, " - f"is active: {self.active}, " - f"is scheduled: {self.scheduled}" - ) - - class Meta: - constraints = [ - models.UniqueConstraint( - fields=["id", "pipeline_type"], - name="unique_pipeline", - ), - ] - - def is_active(self) -> bool: - return bool(self.active) diff --git a/backend/pipeline/notification.py b/backend/pipeline/notification.py deleted file mode 100644 index bc08f63a4..000000000 --- a/backend/pipeline/notification.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -from typing import Optional - -from notification.helper import NotificationHelper -from notification.models import Notification -from pipeline.dto import PipelineStatusPayload -from pipeline.models import Pipeline - -logger = logging.getLogger(__name__) - - -class PipelineNotification: - def __init__( - self, - pipeline: Pipeline, - execution_id: Optional[str] = None, - error_message: Optional[str] = None, - ) -> None: - self.notifications = Notification.objects.filter( - pipeline=pipeline, is_active=True - ) - self.pipeline = pipeline - self.error_message = error_message - self.execution_id = execution_id - - def send(self): - if not self.notifications.count(): - logger.info(f"No notifications found for pipeline {self.pipeline}") - return - logger.info( - f"Sending pipeline status notification for pipeline {self.pipeline}" - ) - payload_dto = PipelineStatusPayload( - type=self.pipeline.pipeline_type, - pipeline_id=str(self.pipeline.id), - pipeline_name=self.pipeline.pipeline_name, - status=self.pipeline.last_run_status, - execution_id=self.execution_id, - error_message=self.error_message, - ) - - NotificationHelper.send_notification( - notifications=self.notifications, payload=payload_dto.to_dict() - ) diff --git a/backend/pipeline/piepline_api_execution_views.py b/backend/pipeline/piepline_api_execution_views.py deleted file mode 100644 index 763a2963c..000000000 --- a/backend/pipeline/piepline_api_execution_views.py +++ /dev/null @@ -1,43 +0,0 @@ -import logging -from typing import Any - -from pipeline.deployment_helper import DeploymentHelper -from pipeline.models import Pipeline -from rest_framework import status, views -from rest_framework.request import Request -from rest_framework.response import Response -from scheduler.tasks import execute_pipeline_task - -logger = logging.getLogger(__name__) - - -class PipelineApiExecution(views.APIView): - def initialize_request( - self, request: Request, *args: Any, **kwargs: Any - ) -> Request: - """To remove csrf request for public API. - - Args: - request (Request): _description_ - - Returns: - Request: _description_ - """ - setattr(request, "csrf_processing_done", True) - return super().initialize_request(request, *args, **kwargs) - - @DeploymentHelper.validate_api_key - def post( - self, request: Request, org_name: str, pipeline_id: str, pipeline: Pipeline - ) -> Response: - execute_pipeline_task.delay( - workflow_id="", - org_schema=org_name, - execution_action="", - execution_id="", - pipepline_id=pipeline_id, - with_logs=True, - name=pipeline.pipeline_name, - ) - logger.info(f"Triggered {pipeline} by API") - return Response({"message": f"Triggered {pipeline}"}, status=status.HTTP_200_OK) diff --git a/backend/pipeline/pipeline_processor.py b/backend/pipeline/pipeline_processor.py deleted file mode 100644 index be1e19cc7..000000000 --- a/backend/pipeline/pipeline_processor.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -from typing import Optional - -from django.utils import timezone -from pipeline.exceptions import InactivePipelineError -from pipeline.models import Pipeline -from pipeline.notification import PipelineNotification - -logger = logging.getLogger(__name__) - - -class PipelineProcessor: - @staticmethod - def initialize_pipeline_sync(pipeline_id: str) -> Pipeline: - """Fetches and initializes the sync for a pipeline. - - Args: - pipeline_id (str): UUID of the pipeline to sync - """ - pipeline: Pipeline = PipelineProcessor.fetch_pipeline(pipeline_id) - pipeline.run_count = pipeline.run_count + 1 - return PipelineProcessor._update_pipeline_status( - pipeline=pipeline, - status=Pipeline.PipelineStatus.RESTARTING, - is_end=False, - ) - - @staticmethod - def fetch_pipeline(pipeline_id: str, check_active: bool = True) -> Pipeline: - """Retrieves and checks for an active pipeline. - Args: - pipeline_id (str): UUID of the pipeline - check_active (bool): Whether to check if the pipeline is active - - Raises: - InactivePipelineError: If an active pipeline is not found - """ - pipeline: Pipeline = Pipeline.objects.get(pk=pipeline_id) - if check_active and not pipeline.is_active(): - logger.error(f"Inactive pipeline fetched: {pipeline_id}") - raise InactivePipelineError(pipeline_name=pipeline.pipeline_name) - return pipeline - - @classmethod - def get_active_pipeline(cls, pipeline_id: str) -> Optional[Pipeline]: - """Retrieves a list of active pipelines.""" - try: - return cls.fetch_pipeline(pipeline_id, check_active=True) - except Pipeline.DoesNotExist: - return None - - @staticmethod - def _update_pipeline_status( - pipeline: Pipeline, - status: tuple[str, str], - is_end: bool, - is_active: Optional[bool] = None, - ) -> Pipeline: - """Updates pipeline status during execution. - - Raises: - PipelineSaveError: Exception while saving a pipeline - - Returns: - Pipeline: Updated pipeline - """ - if is_end: - pipeline.last_run_time = timezone.now() - if status: - pipeline.last_run_status = status - if is_active is not None: - pipeline.active = is_active - - pipeline.save() - return pipeline - - @staticmethod - def _send_notification( - pipeline: Pipeline, - execution_id: Optional[str] = None, - error_message: Optional[str] = None, - ) -> None: - """Sends a notification for the pipeline. - Args: - pipeline (Pipeline): Pipeline to send notification for - - Returns: - None - """ - pipeline_notification = PipelineNotification( - pipeline=pipeline, execution_id=execution_id, error_message=error_message - ) - pipeline_notification.send() - - @staticmethod - def update_pipeline( - pipeline_guid: Optional[str], - status: tuple[str, str], - is_active: Optional[bool] = None, - execution_id: Optional[str] = None, - error_message: Optional[str] = None, - is_end: bool = False, - ) -> None: - if not pipeline_guid: - return - # Skip check if we are enabling an inactive pipeline - check_active = not is_active - pipeline: Pipeline = PipelineProcessor.fetch_pipeline( - pipeline_id=pipeline_guid, check_active=check_active - ) - pipeline = PipelineProcessor._update_pipeline_status( - pipeline=pipeline, is_end=is_end, status=status, is_active=is_active - ) - PipelineProcessor._send_notification( - pipeline=pipeline, execution_id=execution_id, error_message=error_message - ) - logger.info(f"Updated pipeline {pipeline_guid} status: {status}") diff --git a/backend/pipeline/public_api_urls.py b/backend/pipeline/public_api_urls.py deleted file mode 100644 index 6384ab48e..000000000 --- a/backend/pipeline/public_api_urls.py +++ /dev/null @@ -1,15 +0,0 @@ -from django.urls import re_path -from pipeline.piepline_api_execution_views import PipelineApiExecution -from rest_framework.urlpatterns import format_suffix_patterns - -execute = PipelineApiExecution.as_view() - -urlpatterns = format_suffix_patterns( - [ - re_path( - r"^api/(?P[\w-]+)/(?P[\w-]+)/?$", - execute, - name="pipeline_api_deployment_execution", - ), - ] -) diff --git a/backend/pipeline/serializers/crud.py b/backend/pipeline/serializers/crud.py deleted file mode 100644 index bc4c0cc61..000000000 --- a/backend/pipeline/serializers/crud.py +++ /dev/null @@ -1,189 +0,0 @@ -import logging -from collections import OrderedDict -from typing import Any, Optional - -from connector.connector_instance_helper import ConnectorInstanceHelper -from connector.models import ConnectorInstance -from connector_processor.connector_processor import ConnectorProcessor -from croniter import croniter -from django.conf import settings -from pipeline.constants import PipelineConstants as PC -from pipeline.constants import PipelineKey as PK -from pipeline.models import Pipeline -from rest_framework import serializers -from rest_framework.serializers import SerializerMethodField -from scheduler.helper import SchedulerHelper -from utils.serializer_utils import SerializerUtils -from workflow_manager.endpoint.models import WorkflowEndpoint - -from backend.serializers import AuditSerializer -from unstract.connectors.connectorkit import Connectorkit - -logger = logging.getLogger(__name__) -DEPLOYMENT_ENDPOINT = settings.API_DEPLOYMENT_PATH_PREFIX + "/pipeline" - - -class PipelineSerializer(AuditSerializer): - - api_endpoint = SerializerMethodField() - - class Meta: - model = Pipeline - fields = "__all__" - - def validate_cron_string(self, value: Optional[str] = None) -> Optional[str]: - """Validate the cron string provided in the serializer data. - - This method is called internally by the serializer to ensure that - the cron string is well-formed and adheres to the correct format. - If the cron string is valid, it is returned. If the string is None - or empty, it returns None. If the string is invalid, a - ValidationError is raised. - - Args: - value (Optional[str], optional): The cron string to validate. - Defaults to None. - - Raises: - serializers.ValidationError: Raised if the cron string is - not in a valid format. - - Returns: - Optional[str]: The validated cron string if it is valid, - otherwise None. - """ - if value is None: - return None - cron_string = value.strip() - # Check if the string is empty - if not cron_string: - return None - - # Validate the cron string - try: - croniter(cron_string) - except Exception as error: - logger.error(f"Invalid cron string '{cron_string}': {error}") - raise serializers.ValidationError("Invalid cron string format.") - - # Check if the frequency is less than 1 hour - cron_parts = cron_string.split() - minute_field = cron_parts[0] - if minute_field == "*" or any(char in minute_field for char in [",", "-", "/"]): - raise serializers.ValidationError( - "Cron schedule can not be more than once per hour. Please provide a " - "cron schedule to run at an hourly or less frequent interval." - ) - - return cron_string - - def get_api_endpoint(self, instance: Pipeline): - """Retrieve the API endpoint URL for a given Pipeline instance. - - This method is an internal serializer call that fetches the - `api_endpoint` property from the provided Pipeline instance. - - Args: - instance (Pipeline): The Pipeline instance for which the API - endpoint URL is being retrieved. - - Returns: - str: The API endpoint URL associated with the Pipeline instance. - """ - return instance.api_endpoint - - def create(self, validated_data: dict[str, Any]) -> Any: - # TODO: Deduce pipeline type based on WF? - validated_data[PK.ACTIVE] = True - return super().create(validated_data) - - def save(self, **kwargs: Any) -> Pipeline: - if PK.CRON_STRING in self.validated_data: - if self.validated_data[PK.CRON_STRING]: - self.validated_data[PK.SCHEDULED] = True - else: - self.validated_data[PK.SCHEDULED] = False - pipeline: Pipeline = super().save(**kwargs) - if pipeline.cron_string is None: - SchedulerHelper.remove_job(pipeline_id=str(pipeline.id)) - else: - SchedulerHelper.add_or_update_job(pipeline) - return pipeline - - def _get_name_and_icon(self, connectors: list[Any], connector_id: Any) -> Any: - for obj in connectors: - if obj["id"] == connector_id: - return obj["name"], obj["icon"] - return PC.NOT_CONFIGURED, None - - def _add_connector_data( - self, - repr: OrderedDict[str, Any], - connector_instance_list: list[Any], - connectors: list[Any], - ) -> OrderedDict[str, Any]: - """Adds connector Input/Output data. - - Args: - sef (_type_): _description_ - repr (OrderedDict[str, Any]): _description_ - - Returns: - OrderedDict[str, Any]: _description_ - """ - repr[PC.SOURCE_NAME] = PC.NOT_CONFIGURED - repr[PC.DESTINATION_NAME] = PC.NOT_CONFIGURED - for instance in connector_instance_list: - if instance.connector_type == "INPUT": - repr[PC.SOURCE_NAME], repr[PC.SOURCE_ICON] = self._get_name_and_icon( - connectors=connectors, - connector_id=instance.connector_id, - ) - if instance.connector_type == "OUTPUT": - repr[PC.DESTINATION_NAME], repr[PC.DESTINATION_ICON] = ( - self._get_name_and_icon( - connectors=connectors, - connector_id=instance.connector_id, - ) - ) - if repr[PC.DESTINATION_NAME] == PC.NOT_CONFIGURED: - try: - check_manual_review = WorkflowEndpoint.objects.get( - workflow=instance.workflow, - endpoint_type=WorkflowEndpoint.EndpointType.DESTINATION, - connection_type=WorkflowEndpoint.ConnectionType.MANUALREVIEW, - ) - if check_manual_review: - repr[PC.DESTINATION_NAME] = "Manual Review" - except Exception as ex: - logger.debug(f"Not a Manual review destination: {ex}") - - return repr - - def to_representation(self, instance: Pipeline) -> OrderedDict[str, Any]: - """To set Source, Destination & Agency for Pipelines.""" - repr: OrderedDict[str, Any] = super().to_representation(instance) - - connector_kit = Connectorkit() - connectors = connector_kit.get_connectors_list() - - if SerializerUtils.check_context_for_GET_or_POST(context=self.context): - workflow = instance.workflow - connector_instance_list = ConnectorInstanceHelper.get_input_output_connector_instances_by_workflow( # noqa - workflow.id - ) - repr[PK.WORKFLOW_ID] = workflow.id - repr[PK.WORKFLOW_NAME] = workflow.workflow_name - repr[PK.CRON_STRING] = repr.pop(PK.CRON_STRING) - repr = self._add_connector_data( - repr=repr, - connector_instance_list=connector_instance_list, - connectors=connectors, - ) - - return repr - - def get_connector_data(self, connector: ConnectorInstance, key: str) -> Any: - return ConnectorProcessor.get_connector_data_with_key( - connector.connector_id, key - ) diff --git a/backend/pipeline/serializers/execute.py b/backend/pipeline/serializers/execute.py deleted file mode 100644 index cc52e3cfd..000000000 --- a/backend/pipeline/serializers/execute.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging - -from pipeline.models import Pipeline -from rest_framework import serializers - -logger = logging.getLogger(__name__) - - -class PipelineExecuteSerializer(serializers.Serializer): - # TODO: Add pipeline as a read_only related field - pipeline_id = serializers.UUIDField() - execution_id = serializers.UUIDField(required=False) - - def validate_pipeline_id(self, value: str) -> str: - try: - Pipeline.objects.get(pk=value) - except Pipeline.DoesNotExist: - raise serializers.ValidationError("Invalid pipeline ID") - return value - - -class DateRangeSerializer(serializers.Serializer): - start_date = serializers.DateTimeField(required=False) - end_date = serializers.DateTimeField(required=False) diff --git a/backend/pipeline/serializers/update.py b/backend/pipeline/serializers/update.py deleted file mode 100644 index cb23a04df..000000000 --- a/backend/pipeline/serializers/update.py +++ /dev/null @@ -1,14 +0,0 @@ -from pipeline.models import Pipeline -from rest_framework import serializers - - -class PipelineUpdateSerializer(serializers.Serializer): - pipeline_id = serializers.UUIDField(required=True) - active = serializers.BooleanField(required=True) - - def validate_pipeline_id(self, value: str) -> str: - try: - Pipeline.objects.get(pk=value) - except Pipeline.DoesNotExist: - raise serializers.ValidationError("Invalid pipeline ID") - return value diff --git a/backend/pipeline/urls.py b/backend/pipeline/urls.py deleted file mode 100644 index da3752fa1..000000000 --- a/backend/pipeline/urls.py +++ /dev/null @@ -1,52 +0,0 @@ -from django.urls import path -from pipeline.constants import PipelineURL -from pipeline.execution_view import PipelineExecutionViewSet -from pipeline.views import PipelineViewSet -from rest_framework.urlpatterns import format_suffix_patterns - -pipeline_list = PipelineViewSet.as_view( - { - "get": "list", - "post": "create", - } -) -execution_list = PipelineExecutionViewSet.as_view( - { - "get": "list", - } -) -pipeline_detail = PipelineViewSet.as_view( - { - "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy", - } -) - -download_postman_collection = PipelineViewSet.as_view( - { - "get": PipelineViewSet.download_postman_collection.__name__, - } -) - -pipeline_execute = PipelineViewSet.as_view({"post": "execute"}) - - -urlpatterns = format_suffix_patterns( - [ - path("pipeline/", pipeline_list, name=PipelineURL.LIST), - path("pipeline//", pipeline_detail, name=PipelineURL.DETAIL), - path( - "pipeline//executions/", - execution_list, - name=PipelineURL.EXECUTIONS, - ), - path("pipeline/execute/", pipeline_execute, name=PipelineURL.EXECUTE), - path( - "pipeline/api/postman_collection//", - download_postman_collection, - name="download_pipeline_postman_collection", - ), - ] -) diff --git a/backend/pipeline/views.py b/backend/pipeline/views.py deleted file mode 100644 index 09e9bd204..000000000 --- a/backend/pipeline/views.py +++ /dev/null @@ -1,115 +0,0 @@ -import json -import logging -from typing import Optional - -from account.custom_exceptions import DuplicateData -from api.exceptions import NoActiveAPIKeyError -from api.key_helper import KeyHelper -from api.postman_collection.dto import PostmanCollection -from django.db import IntegrityError -from django.db.models import QuerySet -from django.http import HttpResponse -from permissions.permission import IsOwner -from pipeline.constants import PipelineConstants, PipelineErrors, PipelineExecutionKey -from pipeline.constants import PipelineKey as PK -from pipeline.manager import PipelineManager -from pipeline.models import Pipeline -from pipeline.pipeline_processor import PipelineProcessor -from pipeline.serializers.crud import PipelineSerializer -from pipeline.serializers.execute import PipelineExecuteSerializer as ExecuteSerializer -from rest_framework import serializers, status, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.versioning import URLPathVersioning -from scheduler.helper import SchedulerHelper - -logger = logging.getLogger(__name__) - - -class PipelineViewSet(viewsets.ModelViewSet): - versioning_class = URLPathVersioning - queryset = Pipeline.objects.all() - permission_classes = [IsOwner] - serializer_class = PipelineSerializer - - def get_queryset(self) -> Optional[QuerySet]: - type = self.request.query_params.get(PipelineConstants.TYPE) - if type is not None: - queryset = Pipeline.objects.filter( - created_by=self.request.user, pipeline_type=type - ) - return queryset - elif type is None: - queryset = Pipeline.objects.filter(created_by=self.request.user) - return queryset - - def get_serializer_class(self) -> serializers.Serializer: - if self.action == "execute": - return ExecuteSerializer - else: - return PipelineSerializer - - # TODO: Refactor to perform an action with explicit arguments - # For eg, passing pipeline ID and with_log=False -> executes pipeline - # For FE however we call the same API twice - # (first call generates execution ID) - def execute(self, request: Request) -> Response: - serializer: ExecuteSerializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - execution_id = serializer.validated_data.get("execution_id", None) - pipeline_id = serializer.validated_data[PK.PIPELINE_ID] - - execution = PipelineManager.execute_pipeline( - request=request, - pipeline_id=pipeline_id, - execution_id=execution_id, - ) - pipeline: Pipeline = PipelineProcessor.fetch_pipeline(pipeline_id) - serializer = PipelineSerializer(pipeline) - response_data = { - PipelineExecutionKey.PIPELINE: serializer.data, - PipelineExecutionKey.EXECUTION: execution.data, - } - return Response(data=response_data, status=status.HTTP_200_OK) - - def create(self, request: Request) -> Response: - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - try: - pipeline_instance = serializer.save() - # Create API key using the created instance - KeyHelper.create_api_key(pipeline_instance) - except IntegrityError: - raise DuplicateData( - f"{PipelineErrors.PIPELINE_EXISTS}, " f"{PipelineErrors.DUPLICATE_API}" - ) - return Response(data=serializer.data, status=status.HTTP_201_CREATED) - - def perform_destroy(self, instance: Pipeline) -> None: - pipeline_to_remove = str(instance.pk) - super().perform_destroy(instance) - return SchedulerHelper.remove_job(pipeline_to_remove) - - @action(detail=True, methods=["get"]) - def download_postman_collection( - self, request: Request, pk: Optional[str] = None - ) -> Response: - """Downloads a Postman Collection of the API deployment instance.""" - instance: Pipeline = self.get_object() - api_key_inst = instance.apikey_set.filter(is_active=True).first() - if not api_key_inst: - logger.error(f"No active API key set for pipeline {instance}") - raise NoActiveAPIKeyError(deployment_name=instance.pipeline_name) - - # Create a PostmanCollection for a Pipeline - postman_collection = PostmanCollection.create( - instance=instance, api_key=api_key_inst.api_key - ) - response = HttpResponse( - json.dumps(postman_collection.to_dict()), content_type="application/json" - ) - response["Content-Disposition"] = ( - f'attachment; filename="{instance.pipeline_name}.json"' - ) - return response diff --git a/backend/platform_settings/__init__.py b/backend/platform_settings/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/platform_settings/admin.py b/backend/platform_settings/admin.py deleted file mode 100644 index 846f6b406..000000000 --- a/backend/platform_settings/admin.py +++ /dev/null @@ -1 +0,0 @@ -# Register your models here. diff --git a/backend/platform_settings/apps.py b/backend/platform_settings/apps.py deleted file mode 100644 index 1ecd0f245..000000000 --- a/backend/platform_settings/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class PlatformSettingsConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "platform_settings" diff --git a/backend/platform_settings/constants.py b/backend/platform_settings/constants.py deleted file mode 100644 index 29ba63684..000000000 --- a/backend/platform_settings/constants.py +++ /dev/null @@ -1,14 +0,0 @@ -class PlatformServiceConstants: - IS_ACTIVE = "is_active" - KEY = "key" - ORGANIZATION = "organization" - ID = "id" - ACTIVATE = "ACTIVATE" - DEACTIVATE = "DEACTIVATE" - ACTION = "action" - KEY_NAME = "key_name" - - -class ErrorMessage: - KEY_EXIST = "Key name already exists" - DUPLICATE_API = "It appears that a duplicate call may have been made." diff --git a/backend/platform_settings/exceptions.py b/backend/platform_settings/exceptions.py deleted file mode 100644 index 538b270fd..000000000 --- a/backend/platform_settings/exceptions.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class InternalServiceError(APIException): - status_code = 500 - default_detail = "Internal error occurred while platform key operations." - - -class UserForbidden(APIException): - status_code = 403 - default_detail = ( - "User is forbidden from performing this action. Please contact admin" - ) - - -class KeyCountExceeded(APIException): - status_code = 403 - default_detail = ( - "Maximum key count is exceeded. Please delete one before generation." - ) - - -class FoundActiveKey(APIException): - status_code = 403 - default_detail = "Only one active key allowed at a time." - - -class ActiveKeyNotFound(APIException): - status_code = 404 - default_detail = "At least one active platform key should be available" - - -class InvalidRequest(APIException): - status_code = 401 - default_detail = "Invalid Request" - - -class DuplicateData(APIException): - status_code = 400 - default_detail = "Duplicate Data" - - def __init__(self, detail: Optional[str] = None, code: Optional[int] = None): - if detail is not None: - self.detail = detail - if code is not None: - self.code = code - super().__init__(detail, code) diff --git a/backend/platform_settings/migrations/__init__.py b/backend/platform_settings/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/platform_settings/models.py b/backend/platform_settings/models.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/platform_settings/platform_auth_helper.py b/backend/platform_settings/platform_auth_helper.py deleted file mode 100644 index c3d7d2d4b..000000000 --- a/backend/platform_settings/platform_auth_helper.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging - -from account.authentication_controller import AuthenticationController -from account.models import Organization, PlatformKey, User -from platform_settings.exceptions import KeyCountExceeded, UserForbidden -from tenant_account.models import OrganizationMember - -PLATFORM_KEY_COUNT = 2 - -logger = logging.getLogger(__name__) - - -class PlatformAuthHelper: - """Class to hold helper functions for Platform settings authentication.""" - - @staticmethod - def validate_user_role(user: User) -> None: - """This method validates if the logged in user has admin role for - performing appropriate actions. - - Args: - user (User): Logged in user from context - """ - auth_controller = AuthenticationController() - try: - member: OrganizationMember = ( - auth_controller.get_organization_members_by_user(user=user) - ) - except Exception as error: - logger.error( - f"Error occurred while fetching organization for user : {error}" - ) - raise error - if not auth_controller.is_admin_by_role(member.role): - logger.error("User is not having right access to perform this operation.") - raise UserForbidden() - else: - pass - - @staticmethod - def validate_token_count(organization: Organization) -> None: - if ( - PlatformKey.objects.filter(organization=organization).count() - >= PLATFORM_KEY_COUNT - ): - logger.error( - f"Only {PLATFORM_KEY_COUNT} keys are support at a time. Count exceeded." - ) - raise KeyCountExceeded() - else: - pass diff --git a/backend/platform_settings/platform_auth_service.py b/backend/platform_settings/platform_auth_service.py deleted file mode 100644 index 318c23591..000000000 --- a/backend/platform_settings/platform_auth_service.py +++ /dev/null @@ -1,242 +0,0 @@ -import logging -import uuid -from typing import Any, Optional - -from account.models import Organization, PlatformKey, User -from django.db import IntegrityError, connection -from django_tenants.utils import get_tenant_model -from platform_settings.exceptions import ( - ActiveKeyNotFound, - DuplicateData, - InternalServiceError, -) -from tenant_account.constants import ErrorMessage, PlatformServiceConstants - -logger = logging.getLogger(__name__) - - -class PlatformAuthenticationService: - """Service class to hold Platform service authentication and validation. - - Supports generation, refresh, revoke and toggle of active keys. - """ - - @staticmethod - def generate_platform_key( - is_active: bool, - key_name: str, - user: User, - organization: Optional[Organization] = None, - ) -> dict[str, Any]: - """Method to support generation of new platform key. Throws error when - maximum count is exceeded. Forbids for user other than admin - permission. - - Args: - key_name (str): Value of the key - is_active (bool): By default the key is False - user (User): User object representing the user generating the key - organization (Optional[Organization], optional): - Org the key belongs to. Defaults to None. - - Returns: - dict[str, Any]: - A dictionary containing the generated platform key details, - including the id, key name, and key value. - Raises: - DuplicateData: If a platform key with the same key name - already exists for the organization. - InternalServiceError: If an internal error occurs while - generating the platform key. - """ - organization = organization or connection.tenant - if not organization: - raise InternalServiceError("No valid organization provided") - try: - # TODO : Add encryption to Platform keys - # id is added here to avoid passing of keys in transactions. - platform_key: PlatformKey = PlatformKey( - id=str(uuid.uuid4()), - key=str(uuid.uuid4()), - is_active=is_active, - organization=organization, - key_name=key_name, - created_by=user, - modified_by=user, - ) - platform_key.save() - result: dict[str, Any] = {} - result[PlatformServiceConstants.ID] = platform_key.id - result[PlatformServiceConstants.KEY_NAME] = platform_key.key_name - result[PlatformServiceConstants.KEY] = platform_key.key - - logger.info(f"platform_key is generated for {organization.id}") - return result - except IntegrityError as error: - logger.error(f"Integrity error - failed to generate platform key : {error}") - raise DuplicateData( - f"{ErrorMessage.KEY_EXIST}, \ - {ErrorMessage.DUPLICATE_API}" - ) - except Exception as error: - logger.error(f"Failed to generate platform key : {error}") - raise InternalServiceError() - - @staticmethod - def delete_platform_key(id: str) -> None: - """This is a delete operation. Use this function only if you know what - you are doing. - - Args: - id (str): _description_ - - Raises: - error: _description_ - """ - try: - platform_key: PlatformKey = PlatformKey.objects.get(pk=id) - platform_key.delete() - logger.info(f"platform_key {id} is deleted") - except IntegrityError as error: - logger.error(f"Failed to delete platform key : {error}") - raise DuplicateData( - f"{ErrorMessage.KEY_EXIST}, \ - {ErrorMessage.DUPLICATE_API}" - ) - except Exception as error: - logger.error(f"Failed to delete platform key : {error}") - raise InternalServiceError() - - @staticmethod - def refresh_platform_key(id: str, user: User) -> dict[str, Any]: - """Method to refresh a platform key. - - Args: - id (str): Unique id of the key to be refreshed - new_key (str): Value to be updated. - - Raises: - error: IntegrityError - """ - try: - result: dict[str, Any] = {} - platform_key: PlatformKey = PlatformKey.objects.get(pk=id) - platform_key.key = str(uuid.uuid4()) - platform_key.modified_by = user - platform_key.save() - result[PlatformServiceConstants.ID] = platform_key.id - result[PlatformServiceConstants.KEY_NAME] = platform_key.key_name - result[PlatformServiceConstants.KEY] = platform_key.key - - logger.info(f"platform_key {id} is updated by user {user.id}") - return result - except IntegrityError as error: - logger.error(f"Integrity error - failed to refresh platform key : {error}") - raise DuplicateData( - f"{ErrorMessage.KEY_EXIST}, \ - {ErrorMessage.DUPLICATE_API}" - ) - except Exception as error: - logger.error(f"Failed to refresh platform key : {error}") - raise InternalServiceError() - - @staticmethod - def toggle_platform_key_status( - platform_key: PlatformKey, action: str, user: User - ) -> None: - """Method to activate/deactivate a platform key. Only one active key is - allowed at a time. On change or setting, other keys are deactivated. - - Args: - id (str): Id of the key to be toggled. - action (str): activate/deactivate - - Raises: - error: IntegrityError - """ - try: - organization = connection.tenant - platform_key.modified_by = user - if action == PlatformServiceConstants.ACTIVATE: - active_keys: list[PlatformKey] = PlatformKey.objects.filter( - is_active=True, organization=organization - ).all() - # Deactivates all keys - for key in active_keys: - key.is_active = False - key.modified_by = user - key.save() - # Activates the chosen key. - platform_key.is_active = True - platform_key.save() - if action == PlatformServiceConstants.DEACTIVATE: - platform_key.is_active = False - platform_key.save() - except IntegrityError as error: - logger.error( - "IntegrityError - Failed to activate/deactivate " - f"platform key : {error}" - ) - raise DuplicateData( - f"{ErrorMessage.KEY_EXIST}, \ - {ErrorMessage.DUPLICATE_API}" - ) - except Exception as error: - logger.error(f"Failed to activate/deactivate platform key : {error}") - raise InternalServiceError() - - @staticmethod - def list_platform_key_ids() -> list[PlatformKey]: - """Method to fetch list of platform keys unique ids for internal usage. - - Returns: - Any: List of platform keys. - """ - try: - organization_id = connection.tenant.id - platform_keys: list[PlatformKey] = PlatformKey.objects.filter( - organization=organization_id - ) - return platform_keys - except Exception as error: - logger.error(f"Failed to fetch platform key ids : {error}") - raise InternalServiceError() - - @staticmethod - def fetch_platform_key_id() -> Any: - """Method to fetch list of platform keys unique ids for internal usage. - - Returns: - Any: List of platform keys. - """ - try: - platform_key: list[PlatformKey] = PlatformKey.objects.all() - return platform_key - except Exception as error: - logger.error(f"Failed to fetch platform key ids : {error}") - raise InternalServiceError() - - @staticmethod - def get_active_platform_key( - organization_id: Optional[str] = None, - ) -> PlatformKey: - """Method to fetch active key. - - Considering only one active key is allowed at a time - Returns: - Any: platformKey. - """ - try: - organization_id = organization_id or connection.tenant.schema_name - organization: Organization = get_tenant_model().objects.get( - schema_name=organization_id - ) - platform_key: PlatformKey = PlatformKey.objects.get( - organization=organization, is_active=True - ) - return platform_key - except PlatformKey.DoesNotExist: - raise ActiveKeyNotFound() - except Exception as error: - logger.error(f"Failed to fetch platform key : {error}") - raise InternalServiceError() diff --git a/backend/platform_settings/serializers.py b/backend/platform_settings/serializers.py deleted file mode 100644 index 603883c66..000000000 --- a/backend/platform_settings/serializers.py +++ /dev/null @@ -1,26 +0,0 @@ -# serializers.py - -from account.models import PlatformKey -from rest_framework import serializers - -from backend.serializers import AuditSerializer - - -class PlatformKeySerializer(AuditSerializer): - class Meta: - model = PlatformKey - fields = "__all__" - - -class PlatformKeyGenerateSerializer(serializers.Serializer): - # Adjust these fields based on your actual serializer - is_active = serializers.BooleanField() - - key_name = serializers.CharField() - - -class PlatformKeyIDSerializer(serializers.Serializer): - id = serializers.CharField() - key_name = serializers.CharField() - key = serializers.CharField() - is_active = serializers.BooleanField() diff --git a/backend/platform_settings/tests.py b/backend/platform_settings/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/platform_settings/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/platform_settings/urls.py b/backend/platform_settings/urls.py deleted file mode 100644 index feb1f5cc1..000000000 --- a/backend/platform_settings/urls.py +++ /dev/null @@ -1,26 +0,0 @@ -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns - -from .views import PlatformKeyViewSet - -platform_key_list = PlatformKeyViewSet.as_view( - {"post": "create", "put": "refresh", "get": "list"} -) -platform_key_update = PlatformKeyViewSet.as_view( - {"put": "toggle_platform_key", "delete": "destroy"} -) - -urlpatterns = format_suffix_patterns( - [ - path( - "keys/", - platform_key_list, - name="generate_platform_key", - ), - path( - "keys//", - platform_key_update, - name="update_platform_key", - ), - ] -) diff --git a/backend/platform_settings/views.py b/backend/platform_settings/views.py deleted file mode 100644 index 8745eafd1..000000000 --- a/backend/platform_settings/views.py +++ /dev/null @@ -1,123 +0,0 @@ -# views.py - -import logging -from typing import Any - -from account.models import Organization, PlatformKey -from django.db import connection -from platform_settings.constants import PlatformServiceConstants -from platform_settings.platform_auth_helper import PlatformAuthHelper -from platform_settings.platform_auth_service import PlatformAuthenticationService -from platform_settings.serializers import ( - PlatformKeyGenerateSerializer, - PlatformKeyIDSerializer, - PlatformKeySerializer, -) -from rest_framework import status, viewsets -from rest_framework.request import Request -from rest_framework.response import Response - -logger = logging.getLogger(__name__) - - -class PlatformKeyViewSet(viewsets.ModelViewSet): - queryset = PlatformKey.objects.all() - serializer_class = PlatformKeySerializer - - def validate_user_role(func: Any) -> Any: - def wrapper( - self: Any, - request: Request, - *args: tuple[Any], - **kwargs: dict[str, Any], - ) -> Any: - PlatformAuthHelper.validate_user_role(request.user) - return func(self, request, *args, **kwargs) - - return wrapper - - @validate_user_role - def list( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - platform_key_ids = PlatformAuthenticationService.list_platform_key_ids() - serializer = PlatformKeyIDSerializer(platform_key_ids, many=True) - return Response( - status=status.HTTP_200_OK, - data=serializer.data, - ) - - @validate_user_role - def refresh( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - """API Endpoint for refreshing platform keys.""" - id = request.data.get(PlatformServiceConstants.ID) - if not id: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={ - "message": "validation error", - "errors": "Mandatory fields missing", - }, - ) - platform_key = PlatformAuthenticationService.refresh_platform_key( - id=id, user=request.user - ) - return Response( - status=status.HTTP_201_CREATED, - data=platform_key, - ) - - @validate_user_role - def destroy( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - instance = self.get_object() - instance.delete() - return Response( - status=status.HTTP_204_NO_CONTENT, - data={"message": "Platform key deleted successfully"}, - ) - - @validate_user_role - def toggle_platform_key( - self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] - ) -> Response: - instance = self.get_object() - action = request.data.get(PlatformServiceConstants.ACTION) - if not action: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={ - "message": "validation error", - "errors": "Mandatory fields missing", - }, - ) - PlatformAuthenticationService.toggle_platform_key_status( - platform_key=instance, action=action, user=request.user - ) - return Response( - status=status.HTTP_201_CREATED, - data={"message": "Platform key toggled successfully"}, - ) - - @validate_user_role - def create(self, request: Request) -> Response: - serializer = PlatformKeyGenerateSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - is_active = request.data.get(PlatformServiceConstants.IS_ACTIVE) - key_name = request.data.get(PlatformServiceConstants.KEY_NAME) - - organization: Organization = connection.tenant - - PlatformAuthHelper.validate_token_count(organization=organization) - - platform_key = PlatformAuthenticationService.generate_platform_key( - is_active=is_active, key_name=key_name, user=request.user - ) - serialized_data = self.serializer_class(platform_key).data - return Response( - status=status.HTTP_201_CREATED, - data=serialized_data, - ) diff --git a/backend/tenant_account/__init__.py b/backend/tenant_account/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/tenant_account/admin.py b/backend/tenant_account/admin.py deleted file mode 100644 index 846f6b406..000000000 --- a/backend/tenant_account/admin.py +++ /dev/null @@ -1 +0,0 @@ -# Register your models here. diff --git a/backend/tenant_account/apps.py b/backend/tenant_account/apps.py deleted file mode 100644 index 8549865f7..000000000 --- a/backend/tenant_account/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class TenantAccountConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "tenant_account" diff --git a/backend/tenant_account/constants.py b/backend/tenant_account/constants.py deleted file mode 100644 index 29ba63684..000000000 --- a/backend/tenant_account/constants.py +++ /dev/null @@ -1,14 +0,0 @@ -class PlatformServiceConstants: - IS_ACTIVE = "is_active" - KEY = "key" - ORGANIZATION = "organization" - ID = "id" - ACTIVATE = "ACTIVATE" - DEACTIVATE = "DEACTIVATE" - ACTION = "action" - KEY_NAME = "key_name" - - -class ErrorMessage: - KEY_EXIST = "Key name already exists" - DUPLICATE_API = "It appears that a duplicate call may have been made." diff --git a/backend/tenant_account/dto.py b/backend/tenant_account/dto.py deleted file mode 100644 index 16c37ea58..000000000 --- a/backend/tenant_account/dto.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class OrganizationLoginResponse: - name: str - display_name: str - organization_id: str - created_at: str - - -@dataclass -class ResetUserPasswordDto: - status: bool - message: str diff --git a/backend/tenant_account/enums.py b/backend/tenant_account/enums.py deleted file mode 100644 index d8209ec2d..000000000 --- a/backend/tenant_account/enums.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class UserRole(Enum): - USER = "user" - ADMIN = "admin" diff --git a/backend/tenant_account/invitation_urls.py b/backend/tenant_account/invitation_urls.py deleted file mode 100644 index 2176aaa6f..000000000 --- a/backend/tenant_account/invitation_urls.py +++ /dev/null @@ -1,20 +0,0 @@ -from django.urls import path -from tenant_account.invitation_views import InvitationViewSet - -invitation_list = InvitationViewSet.as_view( - { - "get": InvitationViewSet.list_invitations.__name__, - } -) - -invitation_details = InvitationViewSet.as_view( - { - "delete": InvitationViewSet.delete_invitation.__name__, - } -) - - -urlpatterns = [ - path("", invitation_list, name="invitation_list"), - path("/", invitation_details, name="invitation_details"), -] diff --git a/backend/tenant_account/invitation_views.py b/backend/tenant_account/invitation_views.py deleted file mode 100644 index e6b0aeb5c..000000000 --- a/backend/tenant_account/invitation_views.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging - -from account.authentication_controller import AuthenticationController -from account.dto import MemberInvitation -from rest_framework import status, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response -from tenant_account.serializer import ListInvitationsResponseSerializer -from utils.user_session import UserSessionUtils - -Logger = logging.getLogger(__name__) - - -class InvitationViewSet(viewsets.ViewSet): - @action(detail=False, methods=["GET"]) - def list_invitations(self, request: Request) -> Response: - auth_controller = AuthenticationController() - invitations: list[MemberInvitation] = auth_controller.get_user_invitations( - organization_id=UserSessionUtils.get_organization_id(request), - ) - serialized_members = ListInvitationsResponseSerializer( - invitations, many=True - ).data - return Response( - status=status.HTTP_200_OK, - data={"message": "success", "members": serialized_members}, - ) - - @action(detail=False, methods=["DELETE"]) - def delete_invitation(self, request: Request, id: str) -> Response: - auth_controller = AuthenticationController() - is_deleted: bool = auth_controller.delete_user_invitation( - organization_id=UserSessionUtils.get_organization_id(request), - invitation_id=id, - ) - if is_deleted: - return Response( - status=status.HTTP_204_NO_CONTENT, - data={"status": "success", "message": "success"}, - ) - else: - return Response( - status=status.HTTP_404_NOT_FOUND, - data={"status": "failed", "message": "failed"}, - ) diff --git a/backend/tenant_account/migrations/0001_initial.py b/backend/tenant_account/migrations/0001_initial.py deleted file mode 100644 index 9cf3e1e69..000000000 --- a/backend/tenant_account/migrations/0001_initial.py +++ /dev/null @@ -1,45 +0,0 @@ -# Generated by Django 4.2.1 on 2023-07-18 15:34 - -import django.contrib.auth.models -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - ("account", "0002_auto_20230718_1040"), - ] - - operations = [ - # Updated the name here as the 002, 0002 step is just name change - migrations.CreateModel( - name="OrganizationMember", - fields=[ - ( - "user_ptr", - models.OneToOneField( - auto_created=True, - on_delete=django.db.models.deletion.CASCADE, - parent_link=True, - primary_key=True, - serialize=False, - to=settings.AUTH_USER_MODEL, - ), - ), - # Added column which is used in 0002 here - ("role", models.CharField(default="admin")), - ], - options={ - "verbose_name": "user", - "verbose_name_plural": "users", - "abstract": False, - }, - bases=("account.user",), - managers=[ - ("objects", django.contrib.auth.models.UserManager()), - ], - ), - ] diff --git a/backend/tenant_account/migrations/0002_organizationmember_delete_user.py b/backend/tenant_account/migrations/0002_organizationmember_delete_user.py deleted file mode 100644 index 68b4a63a2..000000000 --- a/backend/tenant_account/migrations/0002_organizationmember_delete_user.py +++ /dev/null @@ -1,45 +0,0 @@ -# Generated by Django 4.2.1 on 2023-08-21 11:12 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("account", "0002_auto_20230718_1040"), - ("tenant_account", "0001_initial"), - ] - - operations = [ - # # Commenting out here as this is taken care in 0001 - # migrations.CreateModel( - # name="OrganizationMember", - # fields=[ - # ( - # "user_ptr", - # models.OneToOneField( - # auto_created=True, - # on_delete=django.db.models.deletion.CASCADE, - # parent_link=True, - # primary_key=True, - # serialize=False, - # to=settings.AUTH_USER_MODEL, - # ), - # ), - # ("role", models.CharField(default="admin")), - # ], - # options={ - # "verbose_name": "user", - # "verbose_name_plural": "users", - # "abstract": False, - # }, - # bases=("account.user",), - # managers=[ - # ("objects", django.contrib.auth.models.UserManager()), - # ], - # ), - # # https://www.geeksforgeeks.org/what-is-access-exclusive-lock-mode-in-postgreysql/ - # # commenting drop table to ignore AccesExclusive Lock - # migrations.DeleteModel( - # name="User", - # ), - ] diff --git a/backend/tenant_account/migrations/0003_alter_organizationmember_options_and_more.py b/backend/tenant_account/migrations/0003_alter_organizationmember_options_and_more.py deleted file mode 100644 index 3a16c85e9..000000000 --- a/backend/tenant_account/migrations/0003_alter_organizationmember_options_and_more.py +++ /dev/null @@ -1,55 +0,0 @@ -# Generated by Django 4.2.1 on 2023-09-15 12:11 - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("tenant_account", "0002_organizationmember_delete_user"), - ] - - operations = [ - migrations.AlterModelOptions( - name="organizationmember", - options={}, - ), - migrations.AlterModelManagers( - name="organizationmember", - managers=[], - ), - migrations.RenameField( - model_name="organizationmember", - old_name="user_ptr", - new_name="user", - ), - migrations.AlterField( - model_name="organizationmember", - name="user", - field=models.OneToOneField( - default=None, - on_delete=django.db.models.deletion.CASCADE, - related_name="organization_member", - to=settings.AUTH_USER_MODEL, - ), - ), - migrations.AddField( - model_name="organizationmember", - name="member_id", - field=models.BigAutoField( - auto_created=True, - default=None, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - preserve_default=False, - ), - migrations.AlterField( - model_name="organizationmember", - name="role", - field=models.CharField(), - ), - ] diff --git a/backend/tenant_account/migrations/0004_alter_organizationmember_member_id.py b/backend/tenant_account/migrations/0004_alter_organizationmember_member_id.py deleted file mode 100644 index 9392131db..000000000 --- a/backend/tenant_account/migrations/0004_alter_organizationmember_member_id.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 4.2.1 on 2023-09-20 12:28 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("tenant_account", "0003_alter_organizationmember_options_and_more"), - ] - - operations = [ - migrations.AlterField( - model_name="organizationmember", - name="member_id", - field=models.BigAutoField(primary_key=True, serialize=False), - ), - ] diff --git a/backend/tenant_account/migrations/0005_organizationmember_is_onboarding_msg_and_more.py b/backend/tenant_account/migrations/0005_organizationmember_is_onboarding_msg_and_more.py deleted file mode 100644 index a748dfdd5..000000000 --- a/backend/tenant_account/migrations/0005_organizationmember_is_onboarding_msg_and_more.py +++ /dev/null @@ -1,29 +0,0 @@ -# Generated by Django 4.2.1 on 2024-04-30 08:05 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("tenant_account", "0004_alter_organizationmember_member_id"), - ] - - operations = [ - migrations.AddField( - model_name="organizationmember", - name="is_login_onboarding_msg", - field=models.BooleanField( - db_comment="Flag to indicate whether the onboarding messages are shown to user", - default=True, - ), - ), - migrations.AddField( - model_name="organizationmember", - name="is_prompt_studio_onboarding_msg", - field=models.BooleanField( - db_comment="Flag to indicate whether the prompt studio messages are shown to user", - default=True, - ), - ), - ] diff --git a/backend/tenant_account/migrations/0006_alter_organizationmember_is_login_onboarding_msg_and_more.py b/backend/tenant_account/migrations/0006_alter_organizationmember_is_login_onboarding_msg_and_more.py deleted file mode 100644 index e74e522c8..000000000 --- a/backend/tenant_account/migrations/0006_alter_organizationmember_is_login_onboarding_msg_and_more.py +++ /dev/null @@ -1,29 +0,0 @@ -# Generated by Django 4.2.1 on 2024-05-08 12:53 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("tenant_account", "0005_organizationmember_is_onboarding_msg_and_more"), - ] - - operations = [ - migrations.AlterField( - model_name="organizationmember", - name="is_login_onboarding_msg", - field=models.BooleanField( - db_comment="Flag to indicate whether the onboarding messages are shown", - default=True, - ), - ), - migrations.AlterField( - model_name="organizationmember", - name="is_prompt_studio_onboarding_msg", - field=models.BooleanField( - db_comment="Flag to indicate whether the prompt studio messages are shown", - default=True, - ), - ), - ] diff --git a/backend/tenant_account/migrations/__init__.py b/backend/tenant_account/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/tenant_account/models.py b/backend/tenant_account/models.py deleted file mode 100644 index 2ac7f41e5..000000000 --- a/backend/tenant_account/models.py +++ /dev/null @@ -1,27 +0,0 @@ -from account.models import User -from django.db import models - - -class OrganizationMember(models.Model): - member_id = models.BigAutoField(primary_key=True) - user = models.OneToOneField( - User, - on_delete=models.CASCADE, - default=None, - related_name="organization_member", - ) - role = models.CharField() - is_login_onboarding_msg = models.BooleanField( - default=True, - db_comment="Flag to indicate whether the onboarding messages are shown", - ) - is_prompt_studio_onboarding_msg = models.BooleanField( - default=True, - db_comment="Flag to indicate whether the prompt studio messages are shown", - ) - - def __str__(self): # type: ignore - return ( - f"OrganizationMember(" - f"{self.member_id}, role: {self.role}, userId: {self.user.user_id})" - ) diff --git a/backend/tenant_account/organization_member_service.py b/backend/tenant_account/organization_member_service.py deleted file mode 100644 index 354fdeaeb..000000000 --- a/backend/tenant_account/organization_member_service.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Optional - -from tenant_account.models import OrganizationMember -from utils.cache_service import CacheService - - -class OrganizationMemberService: - - @staticmethod - def get_user_by_email(email: str) -> Optional[OrganizationMember]: - try: - return OrganizationMember.objects.get(user__email=email) # type: ignore - except OrganizationMember.DoesNotExist: - return None - - @staticmethod - def get_user_by_user_id(user_id: str) -> Optional[OrganizationMember]: - try: - return OrganizationMember.objects.get(user__user_id=user_id) # type: ignore - except OrganizationMember.DoesNotExist: - return None - - @staticmethod - def get_user_by_id(id: str) -> Optional[OrganizationMember]: - try: - return OrganizationMember.objects.get(user=id) # type: ignore - except OrganizationMember.DoesNotExist: - return None - - @staticmethod - def delete_user(user: OrganizationMember) -> None: - """Delete a user from an organization. - - Parameters: - user (OrganizationMember): The user to delete. - """ - user.delete() - - @staticmethod - def remove_users_by_user_pks(user_pks: list[str]) -> None: - """Remove a users from an organization. - - Parameters: - user_pks (list[str]): The primary keys of the users to remove. - """ - OrganizationMember.objects.filter(user__in=user_pks).delete() - - @classmethod - def remove_user_by_user_id(cls, user_id: str) -> None: - """Remove a user from an organization. - - Parameters: - user_id (str): The user_id of the user to remove. - """ - user = cls.get_user_by_user_id(user_id) - if user: - cls.delete_user(user) - - @staticmethod - def get_organization_user_cache_key(user_id: str, organization_id: str) -> str: - """Get the cache key for a user in an organization. - - Parameters: - organization_id (str): The ID of the organization. - - Returns: - str: The cache key for a user in the organization. - """ - return f"user_organization:{user_id}:{organization_id}" - - @classmethod - def check_user_membership_in_organization_cache( - cls, user_id: str, organization_id: str - ) -> bool: - """Check if a user exists in an organization. - - Parameters: - user_id (str): The ID of the user to check. - organization_id (str): The ID of the organization to check. - - Returns: - bool: True if the user exists in the organization, False otherwise. - """ - user_organization_key = cls.get_organization_user_cache_key( - user_id, organization_id - ) - return CacheService.check_a_key_exist(user_organization_key) - - @classmethod - def set_user_membership_in_organization_cache( - cls, user_id: str, organization_id: str - ) -> None: - """Set a user's membership in an organization in the cache. - - Parameters: - user_id (str): The ID of the user. - organization_id (str): The ID of the organization. - """ - user_organization_key = cls.get_organization_user_cache_key( - user_id, organization_id - ) - CacheService.set_key(user_organization_key, {}) - - @classmethod - def remove_user_membership_in_organization_cache( - cls, user_id: str, organization_id: str - ) -> None: - """Remove a user's membership in an organization from the cache. - - Parameters: - user_id (str): The ID of the user. - organization_id (str): The ID of the organization. - """ - user_organization_key = cls.get_organization_user_cache_key( - user_id, organization_id - ) - CacheService.delete_a_key(user_organization_key) diff --git a/backend/tenant_account/serializer.py b/backend/tenant_account/serializer.py deleted file mode 100644 index 5ab84aa64..000000000 --- a/backend/tenant_account/serializer.py +++ /dev/null @@ -1,159 +0,0 @@ -from collections import OrderedDict -from typing import Any, Optional, Union, cast - -from account.constants import Common -from rest_framework import serializers -from rest_framework.exceptions import ValidationError -from rest_framework.serializers import ModelSerializer -from tenant_account.models import OrganizationMember - - -class OrganizationCallbackSerializer(serializers.Serializer): - id = serializers.CharField(required=False) - - -class OrganizationLoginResponseSerializer(serializers.Serializer): - name = serializers.CharField() - display_name = serializers.CharField() - organization_id = serializers.CharField() - created_at = serializers.CharField() - - -class UserInviteResponseSerializer(serializers.Serializer): - email = serializers.CharField(required=True) - status = serializers.CharField(required=True) - message = serializers.CharField(required=False) - - -class OrganizationMemberSerializer(serializers.ModelSerializer): - email = serializers.CharField(source="user.email", read_only=True) - id = serializers.CharField(source="user.id", read_only=True) - - class Meta: - model = OrganizationMember - fields = ("id", "email", "role") - - -class LimitedUserEmailListSerializer(serializers.ListSerializer): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.max_elements: int = kwargs.pop("max_elements", Common.MAX_EMAIL_IN_REQUEST) - super().__init__(*args, **kwargs) - - def validate(self, data: list[str]) -> Any: - if len(data) > self.max_elements: - raise ValidationError( - f"Exceeded maximum number of elements ({self.max_elements})" - ) - return data - - -class LimitedUserListSerializer(serializers.ListSerializer): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.max_elements: int = kwargs.pop("max_elements", Common.MAX_EMAIL_IN_REQUEST) - super().__init__(*args, **kwargs) - - def validate( - self, data: list[dict[str, Union[str, None]]] - ) -> list[dict[str, Union[str, None]]]: - if len(data) > self.max_elements: - raise ValidationError( - f"Exceeded maximum number of elements ({self.max_elements})" - ) - - for item in data: - if not isinstance(item, dict): - raise ValidationError("Each item in the list must be a dictionary.") - if "email" not in item: - raise ValidationError("Each item in the list must have 'email' key.") - if "role" not in item: - item["role"] = None - - return data - - -class InviteUserSerializer(serializers.Serializer): - users = LimitedUserListSerializer( - required=True, - child=serializers.DictField( - child=serializers.CharField(max_length=255, required=True), - required=False, # Make 'role' field optional - ), - max_elements=Common.MAX_EMAIL_IN_REQUEST, - ) - - def get_users( - self, validated_data: dict[str, Any] - ) -> list[dict[str, Union[str, None]]]: - return validated_data.get("users", []) - - -class RemoveUserFromOrganizationSerializer(serializers.Serializer): - emails = LimitedUserEmailListSerializer( - required=True, - child=serializers.EmailField(required=True), - max_elements=Common.MAX_EMAIL_IN_REQUEST, - ) - - def get_user_emails( - self, validated_data: dict[str, Union[list[str], None]] - ) -> list[str]: - return cast(list[str], validated_data.get(Common.USER_EMAILS, [])) - - -class ChangeUserRoleRequestSerializer(serializers.Serializer): - email = serializers.EmailField(required=True) - role = serializers.CharField(required=True) - - def get_user_email( - self, validated_data: dict[str, Union[str, None]] - ) -> Optional[str]: - return validated_data.get(Common.USER_EMAIL) - - def get_user_role( - self, validated_data: dict[str, Union[str, None]] - ) -> Optional[str]: - return validated_data.get(Common.USER_ROLE) - - -class DeleteInvitationRequestSerializer(serializers.Serializer): - id = serializers.EmailField(required=True) - - def get_id(self, validated_data: dict[str, Union[str, None]]) -> Optional[str]: - return validated_data.get(Common.ID) - - -class UserInfoSerializer(serializers.Serializer): - id = serializers.CharField() - email = serializers.CharField() - name = serializers.CharField() - display_name = serializers.CharField() - family_name = serializers.CharField() - picture = serializers.CharField() - - -class GetRolesResponseSerializer(serializers.Serializer): - id = serializers.CharField() - name = serializers.CharField() - description = serializers.CharField() - - def to_representation(self, instance: Any) -> OrderedDict[str, Any]: - data: OrderedDict[str, Any] = super().to_representation(instance) - return data - - -class ListInvitationsResponseSerializer(serializers.Serializer): - id = serializers.CharField() - email = serializers.CharField() - created_at = serializers.CharField() - expires_at = serializers.CharField() - - def to_representation(self, instance: Any) -> OrderedDict[str, Any]: - data: OrderedDict[str, Any] = super().to_representation(instance) - return data - - -class UpdateFlagSerializer(ModelSerializer): - - class Meta: - model = OrganizationMember - fields = ("is_login_onboarding_msg", "is_prompt_studio_onboarding_msg") diff --git a/backend/tenant_account/templates/land.html b/backend/tenant_account/templates/land.html deleted file mode 100644 index 8cec6bca6..000000000 --- a/backend/tenant_account/templates/land.html +++ /dev/null @@ -1,11 +0,0 @@ - - - - - ZipstackID Django App Example - - -

Welcome Guest

-

dsdasdaddddddddd

- - diff --git a/backend/tenant_account/tests.py b/backend/tenant_account/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/tenant_account/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/tenant_account/urls.py b/backend/tenant_account/urls.py deleted file mode 100644 index 80146e1d0..000000000 --- a/backend/tenant_account/urls.py +++ /dev/null @@ -1,11 +0,0 @@ -from django.urls import include, path -from tenant_account import invitation_urls, users_urls -from tenant_account.views import get_organization, get_roles, reset_password - -urlpatterns = [ - path("roles", get_roles, name="roles"), - path("users/", include(users_urls)), - path("invitation/", include(invitation_urls)), - path("organization", get_organization, name="get_organization"), - path("reset_password", reset_password, name="reset_password"), -] diff --git a/backend/tenant_account/users_urls.py b/backend/tenant_account/users_urls.py deleted file mode 100644 index 08f02208e..000000000 --- a/backend/tenant_account/users_urls.py +++ /dev/null @@ -1,37 +0,0 @@ -from django.urls import path -from tenant_account.users_view import OrganizationUserViewSet - -organization_user_role = OrganizationUserViewSet.as_view( - { - "post": OrganizationUserViewSet.assign_organization_role_to_user.__name__, - "delete": OrganizationUserViewSet.remove_organization_role_from_user.__name__, - } -) - -user_profile = OrganizationUserViewSet.as_view( - { - "get": OrganizationUserViewSet.get_user_profile.__name__, - "put": OrganizationUserViewSet.update_flags.__name__, - } -) - -invite_user = OrganizationUserViewSet.as_view( - { - "post": OrganizationUserViewSet.invite_user.__name__, - } -) - -organization_users = OrganizationUserViewSet.as_view( - { - "get": OrganizationUserViewSet.get_organization_members.__name__, - "delete": OrganizationUserViewSet.remove_members_from_organization.__name__, - } -) - - -urlpatterns = [ - path("", organization_users, name="organization_user"), - path("profile/", user_profile, name="user_profile"), - path("role/", organization_user_role, name="organization_user_role"), - path("invite/", invite_user, name="invite_user"), -] diff --git a/backend/tenant_account/users_view.py b/backend/tenant_account/users_view.py deleted file mode 100644 index ebcebd14a..000000000 --- a/backend/tenant_account/users_view.py +++ /dev/null @@ -1,196 +0,0 @@ -import logging - -from account.authentication_controller import AuthenticationController -from account.exceptions import BadRequestException -from rest_framework import status, viewsets -from rest_framework.decorators import action -from rest_framework.request import Request -from rest_framework.response import Response -from tenant_account.models import OrganizationMember -from tenant_account.serializer import ( - ChangeUserRoleRequestSerializer, - InviteUserSerializer, - OrganizationMemberSerializer, - RemoveUserFromOrganizationSerializer, - UpdateFlagSerializer, - UserInfoSerializer, - UserInviteResponseSerializer, -) -from utils.user_session import UserSessionUtils - -Logger = logging.getLogger(__name__) - - -class OrganizationUserViewSet(viewsets.ViewSet): - @action(detail=False, methods=["POST"]) - def assign_organization_role_to_user(self, request: Request) -> Response: - serializer = ChangeUserRoleRequestSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - user_email = serializer.get_user_email(serializer.validated_data) - role = serializer.get_user_role(serializer.validated_data) - if not (user_email and role): - raise BadRequestException - org_id: str = UserSessionUtils.get_organization_id(request) - auth_controller = AuthenticationController() - update_status = auth_controller.add_user_role( - request.user, org_id, user_email, role - ) - if update_status: - return Response( - status=status.HTTP_200_OK, - data={"status": "success", "message": "success"}, - ) - else: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={"status": "failed", "message": "failed"}, - ) - - @action(detail=False, methods=["DELETE"]) - def remove_organization_role_from_user(self, request: Request) -> Response: - serializer = ChangeUserRoleRequestSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - user_email = serializer.get_user_email(serializer.validated_data) - role = serializer.get_user_role(serializer.validated_data) - if not (user_email and role): - raise BadRequestException - org_id: str = UserSessionUtils.get_organization_id(request) - auth_controller = AuthenticationController() - - auth_controller = AuthenticationController() - update_status = auth_controller.remove_user_role( - request.user, org_id, user_email, role - ) - if update_status: - return Response( - status=status.HTTP_200_OK, - data={"status": "success", "message": "success"}, - ) - else: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={"status": "failed", "message": "failed"}, - ) - - @action(detail=False, methods=["GET"]) - def get_user_profile(self, request: Request) -> Response: - auth_controller = AuthenticationController() - try: - # z_code = request.COOKIES.get(Cookie.Z_CODE) - user_info = auth_controller.get_user_info(request) - role = auth_controller.get_organization_members_by_user(request.user) - if not user_info: - return Response( - status=status.HTTP_404_NOT_FOUND, - data={"message": "User Not Found"}, - ) - serialized_user_info = UserInfoSerializer(user_info).data - # Temporary fix for getting user role along with user info. - # Proper implementation would be adding role field to UserInfo. - serialized_user_info["is_admin"] = auth_controller.is_admin_by_role( - role.role - ) - # changes for displying onboarding msgs - org_member = OrganizationMember.objects.get(user=request.user) - serialized_user_info["login_onboarding_message_displayed"] = ( - org_member.is_login_onboarding_msg - ) - serialized_user_info["prompt_onboarding_message_displayed"] = ( - org_member.is_prompt_studio_onboarding_msg - ) - - return Response( - status=status.HTTP_200_OK, data={"user": serialized_user_info} - ) - except Exception as error: - Logger.error(f"Error while get User : {error}") - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={"message": "Internal Error"}, - ) - - @action(detail=False, methods=["POST"]) - def invite_user(self, request: Request) -> Response: - serializer = InviteUserSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - user_list = serializer.get_users(serializer.validated_data) - auth_controller = AuthenticationController() - invite_response = auth_controller.invite_user( - admin=request.user, - org_id=UserSessionUtils.get_organization_id(request), - user_list=user_list, - ) - - response_serializer = UserInviteResponseSerializer(invite_response, many=True) - - if invite_response and len(invite_response) != 0: - response = Response( - status=status.HTTP_200_OK, - data={"message": response_serializer.data}, - ) - else: - response = Response( - status=status.HTTP_400_BAD_REQUEST, - data={"message": "failed"}, - ) - return response - - @action(detail=False, methods=["DELETE"]) - def remove_members_from_organization(self, request: Request) -> Response: - serializer = RemoveUserFromOrganizationSerializer(data=request.data) - - serializer.is_valid(raise_exception=True) - user_emails = serializer.get_user_emails(serializer.validated_data) - organization_id: str = UserSessionUtils.get_organization_id(request) - - auth_controller = AuthenticationController() - is_updated = auth_controller.remove_users_from_organization( - admin=request.user, - organization_id=organization_id, - user_emails=user_emails, - ) - if is_updated: - return Response( - status=status.HTTP_200_OK, - data={"status": "success", "message": "success"}, - ) - else: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={"status": "failed", "message": "failed"}, - ) - - @action(detail=False, methods=["GET"]) - def get_organization_members(self, request: Request) -> Response: - auth_controller = AuthenticationController() - if UserSessionUtils.get_organization_id(request): - members: list[OrganizationMember] = ( - auth_controller.get_organization_members_by_org_id() - ) - serialized_members = OrganizationMemberSerializer(members, many=True).data - return Response( - status=status.HTTP_200_OK, - data={"message": "success", "members": serialized_members}, - ) - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={"message": "cookie not found"}, - ) - - @action(detail=False, methods=["PUT"]) - def update_flags(self, request: Request) -> Response: - serializer = UpdateFlagSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - org_member = OrganizationMember.objects.get(user=request.user) - org_member.is_login_onboarding_msg = serializer.validated_data.get( - "is_login_onboarding_msg" - ) - - org_member.is_prompt_studio_onboarding_msg = serializer.validated_data.get( - "is_prompt_studio_onboarding_msg" - ) - org_member.save() - return Response( - status=status.HTTP_200_OK, - data={"status": "success", "message": "success"}, - ) diff --git a/backend/tenant_account/views.py b/backend/tenant_account/views.py deleted file mode 100644 index 5f37fd3b2..000000000 --- a/backend/tenant_account/views.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging -from typing import Any - -from account.authentication_controller import AuthenticationController -from account.dto import UserRoleData -from account.models import Organization -from rest_framework import status -from rest_framework.decorators import api_view -from rest_framework.request import Request -from rest_framework.response import Response -from tenant_account.dto import OrganizationLoginResponse, ResetUserPasswordDto -from tenant_account.serializer import ( - GetRolesResponseSerializer, - OrganizationLoginResponseSerializer, -) -from utils.user_session import UserSessionUtils - -logger = logging.getLogger(__name__) - - -@api_view(["GET"]) -def logout(request: Request) -> Response: - auth_controller = AuthenticationController() - return auth_controller.user_logout(request) - - -@api_view(["GET"]) -def get_roles(request: Request) -> Response: - auth_controller = AuthenticationController() - roles: list[UserRoleData] = auth_controller.get_user_roles() - serialized_members = GetRolesResponseSerializer(roles, many=True).data - return Response( - status=status.HTTP_200_OK, - data={"message": "success", "members": serialized_members}, - ) - - -@api_view(["POST"]) -def reset_password(request: Request) -> Response: - auth_controller = AuthenticationController() - data: ResetUserPasswordDto = auth_controller.reset_user_password(request.user) - if data.status: - return Response( - status=status.HTTP_200_OK, - data={"status": "success", "message": data.message}, - ) - else: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={"status": "failed", "message": data.message}, - ) - - -@api_view(["GET"]) -def get_organization(request: Request) -> Response: - auth_controller = AuthenticationController() - try: - organization_id = UserSessionUtils.get_organization_id(request) - org_data = auth_controller.get_organization_info(organization_id) - if not org_data: - return Response( - status=status.HTTP_404_NOT_FOUND, - data={"message": "Org Not Found"}, - ) - response = makeSignupResponse(org_data) - return Response( - status=status.HTTP_201_CREATED, - data={"message": "success", "organization": response}, - ) - - except Exception as error: - logger.error(f"Error while get User : {error}") - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={"message": "Internal Error"}, - ) - - -def makeSignupResponse( - organization: Organization, -) -> Any: - return OrganizationLoginResponseSerializer( - OrganizationLoginResponse( - organization.name, - organization.display_name, - organization.organization_id, - organization.created_at, - ) - ).data diff --git a/backend/tool_instance/__init__.py b/backend/tool_instance/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/tool_instance/admin.py b/backend/tool_instance/admin.py deleted file mode 100644 index 1e1c975ca..000000000 --- a/backend/tool_instance/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import ToolInstance - -admin.site.register(ToolInstance) diff --git a/backend/tool_instance/apps.py b/backend/tool_instance/apps.py deleted file mode 100644 index 1dc6e7b10..000000000 --- a/backend/tool_instance/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class ToolInstanceConfig(AppConfig): - name = "tool_instance" diff --git a/backend/tool_instance/constants.py b/backend/tool_instance/constants.py deleted file mode 100644 index 17a3891c5..000000000 --- a/backend/tool_instance/constants.py +++ /dev/null @@ -1,45 +0,0 @@ -class ToolInstanceKey: - """Dict keys for ToolInstance model.""" - - PK = "id" - TOOL_ID = "tool_id" - VERSION = "version" - METADATA = "metadata" - STEP = "step" - STATUS = "status" - WORKFLOW = "workflow" - INPUT = "input" - OUTPUT = "output" - TI_COUNT = "tool_instance_count" - - -class JsonSchemaKey: - """Dict Keys for Tool's Json schema.""" - - PROPERTIES = "properties" - THEN = "then" - INPUT_FILE_CONNECTOR = "inputFileConnector" - OUTPUT_FILE_CONNECTOR = "outputFileConnector" - OUTPUT_FOLDER = "outputFolder" - ROOT_FOLDER = "rootFolder" - TENANT_ID = "tenant_id" - INPUT_DB_CONNECTOR = "inputDBConnector" - OUTPUT_DB_CONNECTOR = "outputDBConnector" - ENUM = "enum" - PROJECT_DEFAULT = "Project Default" - - -class ToolInstanceErrors: - TOOL_EXISTS = "Tool with this configuration already exists." - DUPLICATE_API = "It appears that a duplicate call may have been made." - - -class ToolKey: - """Dict keys for a Tool.""" - - NAME = "name" - DESCRIPTION = "description" - ICON = "icon" - FUNCTION_NAME = "function_name" - OUTPUT_TYPE = "output_type" - INPUT_TYPE = "input_type" diff --git a/backend/tool_instance/exceptions.py b/backend/tool_instance/exceptions.py deleted file mode 100644 index 69c2c26a5..000000000 --- a/backend/tool_instance/exceptions.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Optional - -from rest_framework.exceptions import APIException - - -class ToolInstanceBaseException(APIException): - def __init__( - self, - detail: Optional[str] = None, - code: Optional[int] = None, - tool_name: Optional[str] = None, - ) -> None: - detail = detail or self.default_detail - if tool_name is not None: - detail = f"{detail} Tool: {tool_name}" - super().__init__(detail, code) - - -class ToolFunctionIsMandatory(ToolInstanceBaseException): - status_code = 400 - default_detail = "Tool function is mandatory." - - -class ToolDoesNotExist(ToolInstanceBaseException): - status_code = 400 - default_detail = "Tool doesn't exist." - - -class FetchToolListFailed(ToolInstanceBaseException): - status_code = 400 - default_detail = "Failed to fetch tool list." - - -class ToolInstantiationError(ToolInstanceBaseException): - status_code = 500 - default_detail = "Error instantiating tool." - - -class BadRequestException(ToolInstanceBaseException): - status_code = 400 - default_detail = "Invalid input." - - -class ToolSettingValidationError(APIException): - status_code = 400 - default_detail = "Error while validating tool's setting." diff --git a/backend/tool_instance/migrations/0001_initial.py b/backend/tool_instance/migrations/0001_initial.py deleted file mode 100644 index e0ff6fb73..000000000 --- a/backend/tool_instance/migrations/0001_initial.py +++ /dev/null @@ -1,136 +0,0 @@ -# Generated by Django 4.2.1 on 2024-01-23 11:18 - -import uuid - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ("connector", "0001_initial"), - ("workflow", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ToolInstance", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ( - "tool_id", - models.CharField( - db_comment="Function name of the tool being used", - max_length=64, - ), - ), - ( - "input", - models.JSONField( - db_comment="Provisional WF input to a tool", null=True - ), - ), - ( - "output", - models.JSONField( - db_comment="Provisional WF output to a tool", null=True - ), - ), - ("version", models.CharField(max_length=16)), - ( - "metadata", - models.JSONField(db_comment="Stores config for a tool"), - ), - ("step", models.IntegerField()), - ( - "status", - models.CharField(default="Ready to start", max_length=32), - ), - ( - "created_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_tools", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "input_db_connector", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="input_db_connector", - to="connector.connectorinstance", - ), - ), - ( - "input_file_connector", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="input_file_connector", - to="connector.connectorinstance", - ), - ), - ( - "modified_by", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="modified_tools", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "output_db_connector", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="output_db_connector", - to="connector.connectorinstance", - ), - ), - ( - "output_file_connector", - models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="output_file_connector", - to="connector.connectorinstance", - ), - ), - ( - "workflow", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="workflow_tool", - to="workflow.workflow", - ), - ), - ], - options={ - "abstract": False, - }, - ), - ] diff --git a/backend/tool_instance/migrations/__init__.py b/backend/tool_instance/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/tool_instance/models.py b/backend/tool_instance/models.py deleted file mode 100644 index 7ddc73cff..000000000 --- a/backend/tool_instance/models.py +++ /dev/null @@ -1,94 +0,0 @@ -import uuid - -from account.models import User -from connector.models import ConnectorInstance -from django.db import models -from django.db.models import QuerySet -from utils.models.base_model import BaseModel -from workflow_manager.workflow.models.workflow import Workflow - -TOOL_ID_LENGTH = 64 -TOOL_VERSION_LENGTH = 16 -TOOL_STATUS_LENGTH = 32 - - -class ToolInstanceManager(models.Manager): - def get_instances_for_workflow( - self, workflow: uuid.UUID - ) -> QuerySet["ToolInstance"]: - return self.filter(workflow=workflow) - - -class ToolInstance(BaseModel): - class Status(models.TextChoices): - PENDING = "PENDING", "Settings Not Configured" - READY = "READY", "Ready to Start" - INITIATED = "INITIATED", "Initialization in Progress" - COMPLETED = "COMPLETED", "Process Completed" - ERROR = "ERROR", "Error Encountered" - - workflow = models.ForeignKey( - Workflow, - on_delete=models.CASCADE, - related_name="workflow_tool", - null=False, - blank=False, - ) - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - tool_id = models.CharField( - max_length=TOOL_ID_LENGTH, - db_comment="Function name of the tool being used", - ) - input = models.JSONField(null=True, db_comment="Provisional WF input to a tool") - output = models.JSONField(null=True, db_comment="Provisional WF output to a tool") - version = models.CharField(max_length=TOOL_VERSION_LENGTH) - metadata = models.JSONField(db_comment="Stores config for a tool") - step = models.IntegerField() - # TODO: Make as an enum supporting fixed values once we have clarity - status = models.CharField(max_length=TOOL_STATUS_LENGTH, default="Ready to start") - created_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="created_tools", - null=True, - blank=True, - ) - modified_by = models.ForeignKey( - User, - on_delete=models.SET_NULL, - related_name="modified_tools", - null=True, - blank=True, - ) - # Added these connectors separately - # for file and db for scalability - input_file_connector = models.ForeignKey( - ConnectorInstance, - on_delete=models.SET_NULL, - related_name="input_file_connector", - null=True, - blank=True, - ) - output_file_connector = models.ForeignKey( - ConnectorInstance, - on_delete=models.SET_NULL, - related_name="output_file_connector", - null=True, - blank=True, - ) - input_db_connector = models.ForeignKey( - ConnectorInstance, - on_delete=models.SET_NULL, - related_name="input_db_connector", - null=True, - blank=True, - ) - output_db_connector = models.ForeignKey( - ConnectorInstance, - on_delete=models.SET_NULL, - related_name="output_db_connector", - null=True, - blank=True, - ) - - objects = ToolInstanceManager() diff --git a/backend/tool_instance/serializers.py b/backend/tool_instance/serializers.py deleted file mode 100644 index f36b66f87..000000000 --- a/backend/tool_instance/serializers.py +++ /dev/null @@ -1,132 +0,0 @@ -import logging -import uuid -from typing import Any - -from prompt_studio.prompt_studio_registry.constants import PromptStudioRegistryKeys -from rest_framework.serializers import ListField, Serializer, UUIDField, ValidationError -from tool_instance.constants import ToolInstanceKey as TIKey -from tool_instance.constants import ToolKey -from tool_instance.exceptions import ToolDoesNotExist -from tool_instance.models import ToolInstance -from tool_instance.tool_instance_helper import ToolInstanceHelper -from tool_instance.tool_processor import ToolProcessor -from unstract.tool_registry.dto import Tool -from workflow_manager.workflow.constants import WorkflowKey -from workflow_manager.workflow.models.workflow import Workflow - -from backend.constants import RequestKey -from backend.serializers import AuditSerializer - -logger = logging.getLogger(__name__) - - -class ToolInstanceSerializer(AuditSerializer): - workflow_id = UUIDField(write_only=True) - - class Meta: - model = ToolInstance - fields = "__all__" - extra_kwargs = { - TIKey.WORKFLOW: { - "required": False, - }, - TIKey.VERSION: { - "required": False, - }, - TIKey.METADATA: { - "required": False, - }, - TIKey.STEP: { - "required": False, - }, - } - - def to_representation(self, instance: ToolInstance) -> dict[str, str]: - rep: dict[str, Any] = super().to_representation(instance) - tool_function = rep.get(TIKey.TOOL_ID) - - if tool_function is None: - raise ToolDoesNotExist() - try: - tool: Tool = ToolProcessor.get_tool_by_uid(tool_function) - except ToolDoesNotExist: - return rep - rep[ToolKey.ICON] = tool.icon - rep[ToolKey.NAME] = tool.properties.display_name - # Need to Change it into better method - if self.context.get(RequestKey.REQUEST): - metadata = ToolInstanceHelper.get_altered_metadata(instance) - if metadata: - rep[TIKey.METADATA] = metadata - return rep - - def create(self, validated_data: dict[str, Any]) -> Any: - workflow_id = validated_data.pop(WorkflowKey.WF_ID) - try: - workflow = Workflow.objects.get(pk=workflow_id) - except Workflow.DoesNotExist: - raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") - validated_data[TIKey.WORKFLOW] = workflow - - if workflow.workflow_tool.count() > 0: - raise ValidationError( - f"Workflow with ID {workflow_id} can't have more than one tool." - ) - - tool_uid = validated_data.get(TIKey.TOOL_ID) - if not tool_uid: - raise ToolDoesNotExist() - - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) - # TODO: Handle other fields once tools SDK is out - validated_data[TIKey.PK] = uuid.uuid4() - # TODO: Use version from tool props - validated_data[TIKey.VERSION] = "" - validated_data[TIKey.METADATA] = { - # TODO: Review and remove tool instance ID - WorkflowKey.WF_TOOL_INSTANCE_ID: str(validated_data[TIKey.PK]), - PromptStudioRegistryKeys.PROMPT_REGISTRY_ID: str(tool_uid), - **ToolProcessor.get_default_settings(tool), - } - if TIKey.STEP not in validated_data: - validated_data[TIKey.STEP] = workflow.workflow_tool.count() + 1 - # Workflow will get activated on adding tools to workflow - if not workflow.is_active: - workflow.is_active = True - workflow.save() - return super().create(validated_data) - - -class ToolInstanceReorderSerializer(Serializer): - workflow_id = UUIDField() - tool_instances = ListField(child=UUIDField()) - - def validate(self, data: dict[str, Any]) -> dict[str, Any]: - workflow_id = data.get(WorkflowKey.WF_ID) - tool_instances = data.get(WorkflowKey.WF_TOOL_INSTANCES, []) - - # Check if the workflow exists - try: - workflow = Workflow.objects.get(pk=workflow_id) - except Workflow.DoesNotExist: - raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") - - # Check if the number of tool instances matches the actual count - tool_instance_count = workflow.workflow_tool.count() - if len(tool_instances) != tool_instance_count: - msg = ( - f"Incorrect number of tool instances passed: " - f"{len(tool_instances)}, expected: {tool_instance_count}" - ) - logger.error(msg) - raise ValidationError(detail=msg) - - # Check if each tool instance exists in the workflow - existing_tool_instance_ids = workflow.workflow_tool.values_list("id", flat=True) - for tool_instance_id in tool_instances: - if tool_instance_id not in existing_tool_instance_ids: - raise ValidationError( - "One or more tool instances do not exist in the workflow." - ) - - return data diff --git a/backend/tool_instance/tests.py b/backend/tool_instance/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/tool_instance/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/tool_instance/tool_instance_helper.py b/backend/tool_instance/tool_instance_helper.py deleted file mode 100644 index b9d92c8c4..000000000 --- a/backend/tool_instance/tool_instance_helper.py +++ /dev/null @@ -1,470 +0,0 @@ -import logging -import os -import uuid -from typing import Any, Optional - -from account.models import User -from adapter_processor.adapter_processor import AdapterProcessor -from adapter_processor.models import AdapterInstance -from connector.connector_instance_helper import ConnectorInstanceHelper -from django.core.exceptions import PermissionDenied -from django.core.exceptions import ValidationError as DjangoValidationError -from jsonschema.exceptions import ValidationError as JSONValidationError -from prompt_studio.prompt_studio_registry.models import PromptStudioRegistry -from tool_instance.constants import JsonSchemaKey -from tool_instance.exceptions import ToolSettingValidationError -from tool_instance.models import ToolInstance -from tool_instance.tool_processor import ToolProcessor -from unstract.sdk.adapters.enums import AdapterTypes -from unstract.sdk.tool.validator import DefaultsGeneratingValidator -from unstract.tool_registry.constants import AdapterPropertyKey -from unstract.tool_registry.dto import Spec, Tool -from unstract.tool_registry.tool_utils import ToolUtils -from workflow_manager.workflow.constants import WorkflowKey - -logger = logging.getLogger(__name__) - - -class ToolInstanceHelper: - @staticmethod - def get_tool_instances_by_workflow( - workflow_id: str, - order_by: str, - lookup: Optional[dict[str, Any]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - ) -> list[ToolInstance]: - wf_filter = {} - if lookup: - wf_filter = lookup - wf_filter[WorkflowKey.WF_ID] = workflow_id - - if limit: - offset_value = 0 if not offset else offset - to = offset_value + limit - return list( - ToolInstance.objects.filter(**wf_filter)[offset_value:to].order_by( - order_by - ) - ) - return list(ToolInstance.objects.filter(**wf_filter).all().order_by(order_by)) - - @staticmethod - def update_instance_metadata( - org_id: str, tool_instance: ToolInstance, metadata: dict[str, Any] - ) -> None: - if ( - JsonSchemaKey.OUTPUT_FILE_CONNECTOR in metadata - and JsonSchemaKey.OUTPUT_FOLDER in metadata - ): - output_connector_name = metadata[JsonSchemaKey.OUTPUT_FILE_CONNECTOR] - output_connector = ConnectorInstanceHelper.get_output_connector_instance_by_name_for_workflow( # noqa - tool_instance.workflow_id, output_connector_name - ) - if output_connector and "path" in output_connector.metadata: - metadata[JsonSchemaKey.OUTPUT_FOLDER] = os.path.join( - output_connector.metadata["path"], - *(metadata[JsonSchemaKey.OUTPUT_FOLDER].split("/")), - ) - if ( - JsonSchemaKey.INPUT_FILE_CONNECTOR in metadata - and JsonSchemaKey.ROOT_FOLDER in metadata - ): - input_connector_name = metadata[JsonSchemaKey.INPUT_FILE_CONNECTOR] - input_connector = ConnectorInstanceHelper.get_input_connector_instance_by_name_for_workflow( # noqa - tool_instance.workflow_id, input_connector_name - ) - - if input_connector and "path" in input_connector.metadata: - metadata[JsonSchemaKey.ROOT_FOLDER] = os.path.join( - input_connector.metadata["path"], - *(metadata[JsonSchemaKey.ROOT_FOLDER].split("/")), - ) - ToolInstanceHelper.update_metadata_with_adapter_instances( - metadata, tool_instance.tool_id - ) - metadata[JsonSchemaKey.TENANT_ID] = org_id - tool_instance.metadata = metadata - tool_instance.save() - - @staticmethod - def update_metadata_with_adapter_properties( - metadata: dict[str, Any], - adapter_key: str, - adapter_property: dict[str, Any], - adapter_type: AdapterTypes, - ) -> None: - """Update the metadata dictionary with adapter properties. - - Parameters: - metadata (dict[str, Any]): - The metadata dictionary to be updated with adapter properties. - adapter_key (str): - The key in the metadata dictionary corresponding to the adapter. - adapter_property (dict[str, Any]): - The properties of the adapter. - adapter_type (AdapterTypes): - The type of the adapter. - - Returns: - None - """ - if adapter_key in metadata: - adapter_name = metadata[adapter_key] - adapter = AdapterProcessor.get_adapter_by_name_and_type( - adapter_type=adapter_type, adapter_name=adapter_name - ) - adapter_id = str(adapter.id) if adapter else None - metadata_key_for_id = adapter_property.get( - AdapterPropertyKey.ADAPTER_ID_KEY, AdapterPropertyKey.ADAPTER_ID - ) - metadata[metadata_key_for_id] = adapter_id - - @staticmethod - def update_metadata_with_adapter_instances( - metadata: dict[str, Any], tool_uid: str - ) -> None: - """ - Update the metadata dictionary with adapter instances. - Parameters: - metadata (dict[str, Any]): - The metadata dictionary to be updated with adapter instances. - - Returns: - None - """ - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) - schema: Spec = ToolUtils.get_json_schema_for_tool(tool) - llm_properties = schema.get_llm_adapter_properties() - embedding_properties = schema.get_embedding_adapter_properties() - vector_db_properties = schema.get_vector_db_adapter_properties() - x2text_properties = schema.get_text_extractor_adapter_properties() - ocr_properties = schema.get_ocr_adapter_properties() - - for adapter_key, adapter_property in llm_properties.items(): - ToolInstanceHelper.update_metadata_with_adapter_properties( - metadata=metadata, - adapter_key=adapter_key, - adapter_property=adapter_property, - adapter_type=AdapterTypes.LLM, - ) - - for adapter_key, adapter_property in embedding_properties.items(): - ToolInstanceHelper.update_metadata_with_adapter_properties( - metadata=metadata, - adapter_key=adapter_key, - adapter_property=adapter_property, - adapter_type=AdapterTypes.EMBEDDING, - ) - - for adapter_key, adapter_property in vector_db_properties.items(): - ToolInstanceHelper.update_metadata_with_adapter_properties( - metadata=metadata, - adapter_key=adapter_key, - adapter_property=adapter_property, - adapter_type=AdapterTypes.VECTOR_DB, - ) - - for adapter_key, adapter_property in x2text_properties.items(): - ToolInstanceHelper.update_metadata_with_adapter_properties( - metadata=metadata, - adapter_key=adapter_key, - adapter_property=adapter_property, - adapter_type=AdapterTypes.X2TEXT, - ) - - for adapter_key, adapter_property in ocr_properties.items(): - ToolInstanceHelper.update_metadata_with_adapter_properties( - metadata=metadata, - adapter_key=adapter_key, - adapter_property=adapter_property, - adapter_type=AdapterTypes.OCR, - ) - - # TODO: Review if adding this metadata is still required - @staticmethod - def get_altered_metadata( - tool_instance: ToolInstance, - ) -> Optional[dict[str, Any]]: - """Get altered metadata by resolving relative paths. - - This method retrieves the metadata from the given tool instance - and checks if there are output and input file connectors. - If output and input file connectors exist in the metadata, - it resolves the relative paths using connector instances. - - Args: - tool_instance (ToolInstance). - - Returns: - Optional[dict[str, Any]]: Altered metadata with resolved relative \ - paths. - """ - metadata: dict[str, Any] = tool_instance.metadata - if ( - JsonSchemaKey.OUTPUT_FILE_CONNECTOR in metadata - and JsonSchemaKey.OUTPUT_FOLDER in metadata - ): - output_connector_name = metadata[JsonSchemaKey.OUTPUT_FILE_CONNECTOR] - output_connector = ConnectorInstanceHelper.get_output_connector_instance_by_name_for_workflow( # noqa - tool_instance.workflow_id, output_connector_name - ) - if output_connector and "path" in output_connector.metadata: - relative_path = ToolInstanceHelper.get_relative_path( - metadata[JsonSchemaKey.OUTPUT_FOLDER], - output_connector.metadata["path"], - ) - metadata[JsonSchemaKey.OUTPUT_FOLDER] = relative_path - if ( - JsonSchemaKey.INPUT_FILE_CONNECTOR in metadata - and JsonSchemaKey.ROOT_FOLDER in metadata - ): - input_connector_name = metadata[JsonSchemaKey.INPUT_FILE_CONNECTOR] - input_connector = ConnectorInstanceHelper.get_input_connector_instance_by_name_for_workflow( # noqa - tool_instance.workflow_id, input_connector_name - ) - if input_connector and "path" in input_connector.metadata: - relative_path = ToolInstanceHelper.get_relative_path( - metadata[JsonSchemaKey.ROOT_FOLDER], - input_connector.metadata["path"], - ) - metadata[JsonSchemaKey.ROOT_FOLDER] = relative_path - return metadata - - @staticmethod - def update_metadata_with_default_adapter( - adapter_type: AdapterTypes, - schema_spec: Spec, - adapter: AdapterInstance, - metadata: dict[str, Any], - ) -> None: - """Update the metadata of a tool instance with default values for - enabled adapters. - - Parameters: - adapter_type (AdapterTypes): The type of adapter to update - the metadata for. - schema_spec (Spec): The schema specification for the tool. - adapter (AdapterInstance): The adapter instance to use for updating - the metadata. - metadata (dict[str, Any]): The metadata dictionary to update. - - Returns: - None - """ - properties = {} - if adapter_type == AdapterTypes.LLM: - properties = schema_spec.get_llm_adapter_properties() - if adapter_type == AdapterTypes.EMBEDDING: - properties = schema_spec.get_embedding_adapter_properties() - if adapter_type == AdapterTypes.VECTOR_DB: - properties = schema_spec.get_vector_db_adapter_properties() - if adapter_type == AdapterTypes.X2TEXT: - properties = schema_spec.get_text_extractor_adapter_properties() - if adapter_type == AdapterTypes.OCR: - properties = schema_spec.get_ocr_adapter_properties() - for adapter_key, adapter_property in properties.items(): - metadata_key_for_id = adapter_property.get( - AdapterPropertyKey.ADAPTER_ID_KEY, AdapterPropertyKey.ADAPTER_ID - ) - metadata[adapter_key] = adapter.adapter_name - metadata[metadata_key_for_id] = str(adapter.id) - - @staticmethod - def update_metadata_with_default_values( - tool_instance: ToolInstance, user: User - ) -> None: - """Update the metadata of a tool instance with default values for - enabled adapters. - - Parameters: - tool_instance (ToolInstance): The tool instance to update the - metadata. - - Returns: - None - """ - metadata: dict[str, Any] = tool_instance.metadata - tool_uuid = tool_instance.tool_id - - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uuid) - schema: Spec = ToolUtils.get_json_schema_for_tool(tool) - - default_adapters = AdapterProcessor.get_default_adapters(user=user) - for adapter in default_adapters: - try: - adapter_type = AdapterTypes(adapter.adapter_type) - ToolInstanceHelper.update_metadata_with_default_adapter( - adapter_type=adapter_type, - schema_spec=schema, - adapter=adapter, - metadata=metadata, - ) - except ValueError: - logger.warning(f"Invalid AdapterType {adapter.adapter_type}") - tool_instance.metadata = metadata - tool_instance.save() - - @staticmethod - def get_relative_path(absolute_path: str, base_path: str) -> str: - if absolute_path.startswith(base_path): - relative_path = os.path.relpath(absolute_path, base_path) - else: - relative_path = absolute_path - if relative_path == ".": - relative_path = "" - return relative_path - - @staticmethod - def reorder_tool_instances(instances_to_reorder: list[uuid.UUID]) -> None: - """Reorders tool instances based on the list of tool UUIDs received. - Saves the instance in the DB. - - Args: - instances_to_reorder (list[uuid.UUID]): Desired order of tool UUIDs - """ - logger.info(f"Reordering instances: {instances_to_reorder}") - for step, tool_instance_id in enumerate(instances_to_reorder): - tool_instance = ToolInstance.objects.get(pk=tool_instance_id) - tool_instance.step = step + 1 - tool_instance.save() - - @staticmethod - def validate_tool_settings( - user: User, tool_uid: str, tool_meta: dict[str, Any] - ) -> bool: - """Function to validate Tools settings.""" - - # check if exported tool is valid for the user who created workflow - ToolInstanceHelper.validate_tool_access(user=user, tool_uid=tool_uid) - ToolInstanceHelper.validate_adapter_permissions( - user=user, tool_uid=tool_uid, tool_meta=tool_meta - ) - - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) - tool_name: str = ( - tool.properties.display_name if tool.properties.display_name else tool_uid - ) - schema_json: dict[str, Any] = ToolProcessor.get_json_schema_for_tool( - tool_uid=tool_uid, user=user - ) - try: - DefaultsGeneratingValidator(schema_json).validate(tool_meta) - except JSONValidationError as e: - logger.error(e, stack_info=True, exc_info=True) - err_msg = e.message - # TODO: Support other JSON validation errors or consider following - # https://github.com/networknt/json-schema-validator/blob/master/doc/cust-msg.md - if e.validator == "required": - for validator_val in e.validator_value: - required_prop = e.schema.get("properties").get(validator_val) - required_display_name = required_prop.get("title") - err_msg = err_msg.replace(validator_val, required_display_name) - elif e.validator == "minItems": - validated_entity_display_name = e.schema.get("title") - err_msg = ( - f"'{validated_entity_display_name}' requires atleast" - f" {e.validator_value} values." - ) - elif e.validator == "maxItems": - validated_entity_display_name = e.schema.get("title") - err_msg = ( - f"'{validated_entity_display_name}' requires atmost" - f" {e.validator_value} values." - ) - else: - logger.warning(f"Unformatted exception sent to user: {err_msg}") - raise ToolSettingValidationError( - f"Error validating tool settings for '{tool_name}': {err_msg}" - ) - return True - - @staticmethod - def validate_adapter_permissions( - user: User, tool_uid: str, tool_meta: dict[str, Any] - ) -> None: - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) - adapter_ids: set[str] = set() - - for llm in tool.properties.adapter.language_models: - if llm.is_enabled and llm.adapter_id: - adapter_id = tool_meta[llm.adapter_id] - elif llm.is_enabled: - adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_LLM_ADAPTER_ID] - - adapter_ids.add(adapter_id) - for vdb in tool.properties.adapter.vector_stores: - if vdb.is_enabled and vdb.adapter_id: - adapter_id = tool_meta[vdb.adapter_id] - elif vdb.is_enabled: - adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_VECTOR_DB_ADAPTER_ID] - - adapter_ids.add(adapter_id) - for embedding in tool.properties.adapter.embedding_services: - if embedding.is_enabled and embedding.adapter_id: - adapter_id = tool_meta[embedding.adapter_id] - elif embedding.is_enabled: - adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_EMBEDDING_ADAPTER_ID] - - adapter_ids.add(adapter_id) - for text_extractor in tool.properties.adapter.text_extractors: - if text_extractor.is_enabled and text_extractor.adapter_id: - adapter_id = tool_meta[text_extractor.adapter_id] - elif text_extractor.is_enabled: - adapter_id = tool_meta[AdapterPropertyKey.DEFAULT_X2TEXT_ADAPTER_ID] - - adapter_ids.add(adapter_id) - - ToolInstanceHelper.validate_adapter_access(user=user, adapter_ids=adapter_ids) - - @staticmethod - def validate_adapter_access( - user: User, - adapter_ids: set[str], - ) -> None: - adapter_instances = AdapterInstance.objects.filter(id__in=adapter_ids).all() - - for adapter_instance in adapter_instances: - if not adapter_instance.is_usable: - logger.error( - "Free usage for the configured sample adapter %s exhausted", - adapter_instance.id, - ) - error_msg = "Permission Error: Free usage for the configured trial adapter exhausted.Please connect your own service accounts to continue.Please see our documentation for more details:https://docs.unstract.com/unstract_platform/setup_accounts/whats_needed" # noqa: E501 - - raise PermissionDenied(error_msg) - - if not ( - adapter_instance.shared_to_org - or adapter_instance.created_by == user - or adapter_instance.shared_users.filter(pk=user.pk).exists() - ): - logger.error( - "User %s doesn't have access to adapter %s", - user.user_id, - adapter_instance.id, - ) - raise PermissionDenied( - "You don't have permission to perform this action." - ) - - @staticmethod - def validate_tool_access( - user: User, - tool_uid: str, - ) -> None: - # HACK: Assume tool_uid is a prompt studio exported tool and query it. - # We suppress ValidationError when tool_uid is of a static tool. - try: - prompt_registry_tool = PromptStudioRegistry.objects.get(pk=tool_uid) - except DjangoValidationError: - logger.info(f"Not validating tool access for tool: {tool_uid}") - return - - if ( - prompt_registry_tool.shared_to_org - or prompt_registry_tool.shared_users.filter(pk=user.pk).exists() - ): - return - else: - raise PermissionDenied("You don't have permission to perform this action.") diff --git a/backend/tool_instance/tool_processor.py b/backend/tool_instance/tool_processor.py deleted file mode 100644 index b7584b419..000000000 --- a/backend/tool_instance/tool_processor.py +++ /dev/null @@ -1,130 +0,0 @@ -import logging -from typing import Any, Optional - -from account.models import User -from adapter_processor.adapter_processor import AdapterProcessor -from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import ( - PromptStudioRegistryHelper, -) -from tool_instance.exceptions import ToolDoesNotExist -from unstract.sdk.adapters.enums import AdapterTypes -from unstract.tool_registry.dto import Spec, Tool -from unstract.tool_registry.tool_registry import ToolRegistry -from unstract.tool_registry.tool_utils import ToolUtils - -logger = logging.getLogger(__name__) - - -class ToolProcessor: - TOOL_NOT_IN_REGISTRY_MESSAGE = "Tool does not exist in registry" - tool_registry = ToolRegistry() - - @staticmethod - def get_tool_by_uid(tool_uid: str) -> Tool: - """Function to get and instantiate a tool for a given tool - settingsId.""" - tool_registry = ToolRegistry() - tool: Optional[Tool] = tool_registry.get_tool_by_uid(tool_uid) - # HACK: Assume tool_uid is prompt_registry_id for fetching a dynamic - # tool made with Prompt Studio. - if not tool: - tool = PromptStudioRegistryHelper.get_tool_by_prompt_registry_id( - prompt_registry_id=tool_uid - ) - if not tool: - raise ToolDoesNotExist( - f"{ToolProcessor.TOOL_NOT_IN_REGISTRY_MESSAGE}: {tool_uid}" - ) - return tool - - @staticmethod - def get_default_settings(tool: Tool) -> dict[str, str]: - """Function to make and fill settings with default values. - - Args: - tool (ToolSettings): tool - - Returns: - dict[str, str]: tool settings - """ - tool_metadata: dict[str, str] = ToolUtils.get_default_settings(tool) - return tool_metadata - - @staticmethod - def get_json_schema_for_tool(tool_uid: str, user: User) -> dict[str, str]: - """Function to Get JSON Schema for Tools.""" - tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) - schema: Spec = ToolUtils.get_json_schema_for_tool(tool) - ToolProcessor.update_schema_with_adapter_configurations( - schema=schema, user=user - ) - schema_json: dict[str, Any] = schema.to_dict() - return schema_json - - @staticmethod - def update_schema_with_adapter_configurations(schema: Spec, user: User) -> None: - """Updates the JSON schema with the available adapter configurations - for the LLM, embedding, and vector DB adapters. - - Args: - schema (Spec): The JSON schema object to be updated. - - Returns: - None. The `schema` object is updated in-place. - """ - llm_keys = schema.get_llm_adapter_properties_keys() - embedding_keys = schema.get_embedding_adapter_properties_keys() - vector_db_keys = schema.get_vector_db_adapter_properties_keys() - x2text_keys = schema.get_text_extractor_adapter_properties_keys() - ocr_keys = schema.get_ocr_adapter_properties_keys() - - if llm_keys: - adapters = AdapterProcessor.get_adapters_by_type( - AdapterTypes.LLM, user=user - ) - for key in llm_keys: - adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) - schema.properties[key]["enum"] = list(adapter_names) - - if embedding_keys: - adapters = AdapterProcessor.get_adapters_by_type( - AdapterTypes.EMBEDDING, user=user - ) - for key in embedding_keys: - adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) - schema.properties[key]["enum"] = list(adapter_names) - - if vector_db_keys: - adapters = AdapterProcessor.get_adapters_by_type( - AdapterTypes.VECTOR_DB, user=user - ) - for key in vector_db_keys: - adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) - schema.properties[key]["enum"] = list(adapter_names) - - if x2text_keys: - adapters = AdapterProcessor.get_adapters_by_type( - AdapterTypes.X2TEXT, user=user - ) - for key in x2text_keys: - adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) - schema.properties[key]["enum"] = list(adapter_names) - - if ocr_keys: - adapters = AdapterProcessor.get_adapters_by_type( - AdapterTypes.OCR, user=user - ) - for key in ocr_keys: - adapter_names = map(lambda adapter: str(adapter.adapter_name), adapters) - schema.properties[key]["enum"] = list(adapter_names) - - @staticmethod - def get_tool_list(user: User) -> list[dict[str, Any]]: - """Function to get a list of tools.""" - tool_registry = ToolRegistry() - prompt_studio_tools: list[dict[str, Any]] = ( - PromptStudioRegistryHelper.fetch_json_for_registry(user) - ) - tool_list: list[dict[str, Any]] = tool_registry.fetch_tools_descriptions() - tool_list = tool_list + prompt_studio_tools - return tool_list diff --git a/backend/tool_instance/urls.py b/backend/tool_instance/urls.py deleted file mode 100644 index dbcb75189..000000000 --- a/backend/tool_instance/urls.py +++ /dev/null @@ -1,46 +0,0 @@ -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns -from tool_instance.views import ToolInstanceViewSet - -from . import views - -tool_instance_list = ToolInstanceViewSet.as_view( - { - "get": "list", - "post": "create", - } -) -tool_instance_detail = ToolInstanceViewSet.as_view( - # fmt: off - { - "get": "retrieve", - "put": "update", - "patch": "partial_update", - "delete": "destroy" - } - # fmt: on -) - -tool_instance_reorder = ToolInstanceViewSet.as_view({"post": "reorder"}) - -urlpatterns = format_suffix_patterns( - [ - path("tool_instance/", tool_instance_list, name="tool-instance-list"), - path( - "tool_instance//", - tool_instance_detail, - name="tool-instance-detail", - ), - path( - "tool_settings_schema/", - views.tool_settings_schema, - name="tool_settings_schema", - ), - path( - "tool_instance/reorder/", - tool_instance_reorder, - name="tool_instance_reorder", - ), - path("tool/", views.get_tool_list, name="tool_list"), - ] -) diff --git a/backend/tool_instance/views.py b/backend/tool_instance/views.py deleted file mode 100644 index b5a1bc237..000000000 --- a/backend/tool_instance/views.py +++ /dev/null @@ -1,167 +0,0 @@ -import logging -import uuid -from typing import Any - -from account.custom_exceptions import DuplicateData -from django.db import IntegrityError -from django.db.models.query import QuerySet -from rest_framework import serializers, status, viewsets -from rest_framework.decorators import api_view -from rest_framework.request import Request -from rest_framework.response import Response -from rest_framework.versioning import URLPathVersioning -from tool_instance.constants import ToolInstanceErrors -from tool_instance.constants import ToolInstanceKey as TIKey -from tool_instance.constants import ToolKey -from tool_instance.exceptions import FetchToolListFailed, ToolFunctionIsMandatory -from tool_instance.models import ToolInstance -from tool_instance.serializers import ( - ToolInstanceReorderSerializer as TIReorderSerializer, -) -from tool_instance.serializers import ToolInstanceSerializer -from tool_instance.tool_instance_helper import ToolInstanceHelper -from tool_instance.tool_processor import ToolProcessor -from utils.filtering import FilterHelper -from utils.user_session import UserSessionUtils -from workflow_manager.workflow.constants import WorkflowKey - -from backend.constants import RequestKey - -logger = logging.getLogger(__name__) - - -@api_view(["GET"]) -def tool_settings_schema(request: Request) -> Response: - if request.method == "GET": - tool_function = request.GET.get(ToolKey.FUNCTION_NAME) - if tool_function is None or tool_function == "": - raise ToolFunctionIsMandatory() - - json_schema = ToolProcessor.get_json_schema_for_tool( - tool_uid=tool_function, user=request.user - ) - return Response(data=json_schema, status=status.HTTP_200_OK) - - -@api_view(("GET",)) -def get_tool_list(request: Request) -> Response: - """Get tool list. - - Fetches a list of tools available in the Tool registry - """ - if request.method == "GET": - try: - logger.info("Fetching tools from the tool registry...") - return Response( - data=ToolProcessor.get_tool_list(request.user), - status=status.HTTP_200_OK, - ) - except Exception as exc: - logger.error(f"Failed to fetch tools: {exc}") - raise FetchToolListFailed - - -class ToolInstanceViewSet(viewsets.ModelViewSet): - versioning_class = URLPathVersioning - queryset = ToolInstance.objects.all() - serializer_class = ToolInstanceSerializer - - def get_queryset(self) -> QuerySet: - filterArgs = FilterHelper.build_filter_args( - self.request, - RequestKey.PROJECT, - RequestKey.CREATED_BY, - RequestKey.WORKFLOW, - ) - if filterArgs: - queryset = ToolInstance.objects.filter( - created_by=self.request.user, **filterArgs - ) - else: - queryset = ToolInstance.objects.filter(created_by=self.request.user) - return queryset - - def get_serializer_class(self) -> serializers.Serializer: - if self.action == "reorder": - return TIReorderSerializer - else: - return ToolInstanceSerializer - - def create(self, request: Any) -> Response: - """Create tool instance. - - Creates a tool instance, useful to add them directly to a - workflow. Its an alternative to creating tool instances through - the LLM response. - """ - - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - try: - self.perform_create(serializer) - except IntegrityError: - raise DuplicateData( - f"{ToolInstanceErrors.TOOL_EXISTS}, " - f"{ToolInstanceErrors.DUPLICATE_API}" - ) - instance: ToolInstance = serializer.instance - ToolInstanceHelper.update_metadata_with_default_values( - instance, user=request.user - ) - headers = self.get_success_headers(serializer.data) - return Response( - serializer.data, status=status.HTTP_201_CREATED, headers=headers - ) - - def perform_destroy(self, instance: ToolInstance) -> None: - """Deletes a tool instance and decrements successor instance's steps. - - Args: - instance (ToolInstance): Instance being deleted. - """ - lookup = {"step__gt": instance.step} - next_tool_instances: list[ToolInstance] = ( - ToolInstanceHelper.get_tool_instances_by_workflow( - instance.workflow.id, TIKey.STEP, lookup=lookup - ) - ) - super().perform_destroy(instance) - - for instance in next_tool_instances: - instance.step = instance.step - 1 - instance.save() - return - - def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response: - """Allows partial updates on a tool instance.""" - instance: ToolInstance = self.get_object() - serializer = self.get_serializer(instance, data=request.data, partial=True) - serializer.is_valid(raise_exception=True) - if serializer.validated_data.get(TIKey.METADATA): - metadata: dict[str, Any] = serializer.validated_data.get(TIKey.METADATA) - - # TODO: Move update logic into serializer - ToolInstanceHelper.update_instance_metadata( - UserSessionUtils.get_organization_id(request), - instance, - metadata, - ) - return Response(serializer.data) - return super().partial_update(request, *args, **kwargs) - - def reorder(self, request: Any, **kwargs: Any) -> Response: - """Reorder tool instances. - - Reorders the tool instances based on a list of UUIDs. - """ - serializer: TIReorderSerializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - wf_id = serializer.validated_data[WorkflowKey.WF_ID] - instances_to_reorder: list[uuid.UUID] = serializer.validated_data[ - WorkflowKey.WF_TOOL_INSTANCES - ] - - ToolInstanceHelper.reorder_tool_instances(instances_to_reorder) - tool_instances = ToolInstance.objects.get_instances_for_workflow(workflow=wf_id) - ti_serializer = ToolInstanceSerializer(instance=tool_instances, many=True) - return Response(ti_serializer.data, status=status.HTTP_200_OK) diff --git a/backend/usage/__init__.py b/backend/usage/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/usage/admin.py b/backend/usage/admin.py deleted file mode 100644 index c7469a4fb..000000000 --- a/backend/usage/admin.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.contrib import admin - -from .models import Usage - -admin.site.register(Usage) diff --git a/backend/usage/apps.py b/backend/usage/apps.py deleted file mode 100644 index abe8f39c4..000000000 --- a/backend/usage/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class UsageConfig(AppConfig): - name = "usage" diff --git a/backend/usage/constants.py b/backend/usage/constants.py deleted file mode 100644 index 8da54da05..000000000 --- a/backend/usage/constants.py +++ /dev/null @@ -1,7 +0,0 @@ -class UsageKeys: - RUN_ID = "run_id" - EMBEDDING_TOKENS = "embedding_tokens" - PROMPT_TOKENS = "prompt_tokens" - COMPLETION_TOKENS = "completion_tokens" - TOTAL_TOKENS = "total_tokens" - COST_IN_DOLLARS = "cost_in_dollars" diff --git a/backend/usage/helper.py b/backend/usage/helper.py deleted file mode 100644 index fd217cd6f..000000000 --- a/backend/usage/helper.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging - -from django.db.models import Sum -from rest_framework.exceptions import APIException - -from .constants import UsageKeys -from .models import Usage - -logger = logging.getLogger(__name__) - - -class UsageHelper: - @staticmethod - def get_aggregated_token_count(run_id: str) -> dict: - """Retrieve aggregated token counts for the given run_id. - - Args: - run_id (str): The identifier for the token usage. - - Returns: - dict: A dictionary containing aggregated token counts - for different token types. - Keys: - - 'embedding_tokens': Total embedding tokens. - - 'prompt_tokens': Total prompt tokens. - - 'completion_tokens': Total completion tokens. - - 'total_tokens': Total tokens. - - Raises: - APIException: For unexpected errors during database operations. - """ - try: - # Aggregate the token counts for the given run_id - usage_summary = Usage.objects.filter(run_id=run_id).aggregate( - embedding_tokens=Sum(UsageKeys.EMBEDDING_TOKENS), - prompt_tokens=Sum(UsageKeys.PROMPT_TOKENS), - completion_tokens=Sum(UsageKeys.COMPLETION_TOKENS), - total_tokens=Sum(UsageKeys.TOTAL_TOKENS), - cost_in_dollars=Sum(UsageKeys.COST_IN_DOLLARS), - ) - - logger.debug(f"Token counts aggregated successfully for run_id: {run_id}") - - # Prepare the result dictionary with None as the default value - result = { - UsageKeys.EMBEDDING_TOKENS: usage_summary.get( - UsageKeys.EMBEDDING_TOKENS - ), - UsageKeys.PROMPT_TOKENS: usage_summary.get(UsageKeys.PROMPT_TOKENS), - UsageKeys.COMPLETION_TOKENS: usage_summary.get( - UsageKeys.COMPLETION_TOKENS - ), - UsageKeys.TOTAL_TOKENS: usage_summary.get(UsageKeys.TOTAL_TOKENS), - UsageKeys.COST_IN_DOLLARS: usage_summary.get(UsageKeys.COST_IN_DOLLARS), - } - return result - except Usage.DoesNotExist: - # Handle the case where no usage data is found for the given run_id - logger.warning(f"Usage data not found for the specified run_id: {run_id}") - return {} - except Exception as e: - # Handle any other exceptions that might occur during the execution - logger.error(f"An unexpected error occurred for run_id {run_id}: {str(e)}") - raise APIException("Error while aggregating token counts") diff --git a/backend/usage/migrations/0001_initial.py b/backend/usage/migrations/0001_initial.py deleted file mode 100644 index 3bad77c1f..000000000 --- a/backend/usage/migrations/0001_initial.py +++ /dev/null @@ -1,44 +0,0 @@ -# Generated by Django 4.2.1 on 2024-04-22 12:55 - -import uuid - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [] - - operations = [ - migrations.CreateModel( - name="Usage", - fields=[ - ("created_at", models.DateTimeField(auto_now_add=True)), - ("modified_at", models.DateTimeField(auto_now=True)), - ( - "id", - models.UUIDField( - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - ("workflow_id", models.CharField(max_length=255)), - ("execution_id", models.CharField(max_length=255)), - ("adapter_instance_id", models.CharField(max_length=255)), - ("run_id", models.CharField(max_length=255)), - ("usage_type", models.CharField(max_length=255)), - ("model_name", models.CharField(max_length=255)), - ("embedding_tokens", models.IntegerField()), - ("prompt_tokens", models.IntegerField()), - ("completion_tokens", models.IntegerField()), - ("total_tokens", models.IntegerField()), - ], - options={ - "db_table": "token_usage", - }, - ), - ] diff --git a/backend/usage/migrations/0002_alter_usage_adapter_instance_id_and_more.py b/backend/usage/migrations/0002_alter_usage_adapter_instance_id_and_more.py deleted file mode 100644 index c18e4801a..000000000 --- a/backend/usage/migrations/0002_alter_usage_adapter_instance_id_and_more.py +++ /dev/null @@ -1,97 +0,0 @@ -# Generated by Django 4.2.1 on 2024-06-01 10:18 - -import uuid - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("usage", "0001_initial"), - ] - - operations = [ - migrations.AlterField( - model_name="usage", - name="adapter_instance_id", - field=models.CharField( - db_comment="Identifier for the adapter instance", max_length=255 - ), - ), - migrations.AlterField( - model_name="usage", - name="completion_tokens", - field=models.IntegerField( - db_comment="Number of tokens used for the completion" - ), - ), - migrations.AlterField( - model_name="usage", - name="embedding_tokens", - field=models.IntegerField(db_comment="Number of tokens used for embedding"), - ), - migrations.AlterField( - model_name="usage", - name="execution_id", - field=models.CharField( - blank=True, - db_comment="Identifier for the execution instance", - max_length=255, - null=True, - ), - ), - migrations.AlterField( - model_name="usage", - name="id", - field=models.UUIDField( - db_comment="Primary key for the usage entry, automatically generated UUID", - default=uuid.uuid4, - editable=False, - primary_key=True, - serialize=False, - ), - ), - migrations.AlterField( - model_name="usage", - name="model_name", - field=models.CharField(db_comment="Name of the model used", max_length=255), - ), - migrations.AlterField( - model_name="usage", - name="prompt_tokens", - field=models.IntegerField( - db_comment="Number of tokens used for the prompt" - ), - ), - migrations.AlterField( - model_name="usage", - name="run_id", - field=models.CharField( - blank=True, - db_comment="Identifier for the run", - max_length=255, - null=True, - ), - ), - migrations.AlterField( - model_name="usage", - name="total_tokens", - field=models.IntegerField(db_comment="Total number of tokens used"), - ), - migrations.AlterField( - model_name="usage", - name="usage_type", - field=models.CharField(db_comment="Type of usage", max_length=255), - ), - migrations.AlterField( - model_name="usage", - name="workflow_id", - field=models.CharField( - blank=True, - db_comment="Identifier for the workflow", - max_length=255, - null=True, - ), - ), - ] diff --git a/backend/usage/migrations/0003_usage_cost_in_dollars_usage_llm_usage_reason_and_more.py b/backend/usage/migrations/0003_usage_cost_in_dollars_usage_llm_usage_reason_and_more.py deleted file mode 100644 index 90f736184..000000000 --- a/backend/usage/migrations/0003_usage_cost_in_dollars_usage_llm_usage_reason_and_more.py +++ /dev/null @@ -1,49 +0,0 @@ -# Generated by Django 4.2.1 on 2024-06-27 08:08 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("usage", "0002_alter_usage_adapter_instance_id_and_more"), - ] - - operations = [ - migrations.AddField( - model_name="usage", - name="cost_in_dollars", - field=models.FloatField( - db_comment="Total number of tokens used", default=0.0 - ), - preserve_default=False, - ), - migrations.AddField( - model_name="usage", - name="llm_usage_reason", - field=models.CharField( - blank=True, - choices=[ - ("extraction", "Extraction"), - ("challenge", "Challenge"), - ("summarize", "Summarize"), - ], - db_comment="Reason for LLM usage. Empty if usage_type is 'embedding'. ", - max_length=255, - null=True, - ), - ), - migrations.AlterField( - model_name="usage", - name="usage_type", - field=models.CharField( - choices=[("llm", "LLM Usage"), ("embedding", "Embedding Usage")], - db_comment="Type of usage, either 'llm' or 'embedding'", - max_length=255, - ), - ), - migrations.AddIndex( - model_name="usage", - index=models.Index(fields=["run_id"], name="token_usage_run_id_cd3578_idx"), - ), - ] diff --git a/backend/usage/migrations/__init__.py b/backend/usage/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/usage/models.py b/backend/usage/models.py deleted file mode 100644 index 63ac10a47..000000000 --- a/backend/usage/models.py +++ /dev/null @@ -1,72 +0,0 @@ -import uuid - -from django.db import models -from utils.models.base_model import BaseModel - - -class UsageType(models.TextChoices): - LLM = "llm", "LLM Usage" - EMBEDDING = "embedding", "Embedding Usage" - - -class LLMUsageReason(models.TextChoices): - EXTRACTION = "extraction", "Extraction" - CHALLENGE = "challenge", "Challenge" - SUMMARIZE = "summarize", "Summarize" - - -class Usage(BaseModel): - id = models.UUIDField( - primary_key=True, - default=uuid.uuid4, - editable=False, - db_comment="Primary key for the usage entry, automatically generated UUID", - ) - workflow_id = models.CharField( - max_length=255, null=True, blank=True, db_comment="Identifier for the workflow" - ) - execution_id = models.CharField( - max_length=255, - null=True, - blank=True, - db_comment="Identifier for the execution instance", - ) - adapter_instance_id = models.CharField( - max_length=255, db_comment="Identifier for the adapter instance" - ) - run_id = models.CharField( - max_length=255, null=True, blank=True, db_comment="Identifier for the run" - ) - usage_type = models.CharField( - max_length=255, - choices=UsageType.choices, - db_comment="Type of usage, either 'llm' or 'embedding'", - ) - llm_usage_reason = models.CharField( - max_length=255, - choices=LLMUsageReason.choices, - null=True, - blank=True, - db_comment="Reason for LLM usage. Empty if usage_type is 'embedding'. ", - ) - model_name = models.CharField(max_length=255, db_comment="Name of the model used") - embedding_tokens = models.IntegerField( - db_comment="Number of tokens used for embedding" - ) - prompt_tokens = models.IntegerField( - db_comment="Number of tokens used for the prompt" - ) - completion_tokens = models.IntegerField( - db_comment="Number of tokens used for the completion" - ) - total_tokens = models.IntegerField(db_comment="Total number of tokens used") - cost_in_dollars = models.FloatField(db_comment="Total number of tokens used") - - def __str__(self): - return str(self.id) - - class Meta: - db_table = "token_usage" - indexes = [ - models.Index(fields=["run_id"]), - ] diff --git a/backend/usage/serializers.py b/backend/usage/serializers.py deleted file mode 100644 index eb1f2c326..000000000 --- a/backend/usage/serializers.py +++ /dev/null @@ -1,5 +0,0 @@ -from rest_framework import serializers - - -class GetUsageSerializer(serializers.Serializer): - run_id = serializers.CharField(required=True) diff --git a/backend/usage/tests.py b/backend/usage/tests.py deleted file mode 100644 index a39b155ac..000000000 --- a/backend/usage/tests.py +++ /dev/null @@ -1 +0,0 @@ -# Create your tests here. diff --git a/backend/usage/urls.py b/backend/usage/urls.py deleted file mode 100644 index 9c4fb95c7..000000000 --- a/backend/usage/urls.py +++ /dev/null @@ -1,16 +0,0 @@ -from django.urls import path -from rest_framework.urlpatterns import format_suffix_patterns - -from .views import UsageView - -get_token_usage = UsageView.as_view({"get": "get_token_usage"}) - -urlpatterns = format_suffix_patterns( - [ - path( - "get_token_usage/", - get_token_usage, - name="get-token-usage", - ), - ] -) diff --git a/backend/usage/views.py b/backend/usage/views.py deleted file mode 100644 index 66438fd4e..000000000 --- a/backend/usage/views.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging - -from django.http import HttpRequest -from rest_framework import status, viewsets -from rest_framework.decorators import action -from rest_framework.response import Response - -from .constants import UsageKeys -from .helper import UsageHelper -from .serializers import GetUsageSerializer - -logger = logging.getLogger(__name__) - - -class UsageView(viewsets.ModelViewSet): - """Viewset for managing Usage-related operations.""" - - @action(detail=True, methods=["get"]) - def get_token_usage(self, request: HttpRequest) -> Response: - """Retrieves the aggregated token usage for a given run_id. - - This method validates the 'run_id' query parameter, aggregates the token - usage statistics for the specified run_id, and returns the results. - - Args: - request (HttpRequest): The HTTP request object containing the - query parameters. - - Returns: - Response: A Response object containing the aggregated token usage data - with HTTP 200 OK status if successful, or an error message and - appropriate HTTP status if an error occurs. - """ - - # Validate the query parameters using the serializer - # This ensures that 'run_id' is present and valid - serializer = GetUsageSerializer(data=self.request.query_params) - serializer.is_valid(raise_exception=True) - run_id = serializer.validated_data.get(UsageKeys.RUN_ID) - - # Retrieve aggregated token count for the given run_id. - result: dict = UsageHelper.get_aggregated_token_count(run_id=run_id) - - # Log the successful completion of the operation - logger.debug(f"Token usage retrieved successfully for run_id: {run_id}") - - # Return the result - return Response(status=status.HTTP_200_OK, data=result)