diff --git a/backend/prompt_studio/prompt_studio_core_v2/constants.py b/backend/prompt_studio/prompt_studio_core_v2/constants.py index cb335b90f..2c6a80ac6 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/constants.py +++ b/backend/prompt_studio/prompt_studio_core_v2/constants.py @@ -96,6 +96,7 @@ class ToolStudioPromptKeys: RECORD = "record" FILE_PATH = "file_path" ENABLE_HIGHLIGHT = "enable_highlight" + EXECUTION_SOURCE = "execution_source" class FileViewTypes: @@ -132,3 +133,15 @@ class DefaultPrompts: "Do not include any explanation in the reply. " "Only include the extracted information in the reply." ) + + +class ExecutionSource(Enum): + """Enum to indicate the source of invocation. + Any new sources can be added to this enum. + This is to indicate the prompt service. + + Args: + Enum (_type_): ide/tool + """ + + IDE = "ide" diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py index 6bcacf340..7985173a1 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py @@ -19,7 +19,11 @@ from prompt_studio.prompt_profile_manager_v2.profile_manager_helper import ( ProfileManagerHelper, ) -from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus, LogLevels +from prompt_studio.prompt_studio_core_v2.constants import ( + ExecutionSource, + IndexingStatus, + LogLevels, +) from prompt_studio.prompt_studio_core_v2.constants import ( ToolStudioPromptKeys as TSPKeys, ) @@ -1176,6 +1180,7 @@ def _fetch_single_pass_response( TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_NAME: doc_name, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), + TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, } util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) diff --git a/prompt-service/src/unstract/prompt_service/constants.py b/prompt-service/src/unstract/prompt_service/constants.py index 9e6af6f9d..4f30dd11d 100644 --- a/prompt-service/src/unstract/prompt_service/constants.py +++ b/prompt-service/src/unstract/prompt_service/constants.py @@ -72,6 +72,7 @@ class PromptServiceContants: FILE_PATH = "file_path" HIGHLIGHT_DATA = "highlight_data" CONFIDENCE_DATA = "confidence_data" + EXECUTION_SOURCE = "execution_source" METRICS = "metrics" @@ -101,3 +102,20 @@ class DBTableV2: PROMPT_STUDIO_REGISTRY = "prompt_studio_registry" PLATFORM_KEY = "platform_key" TOKEN_USAGE = "usage" + + +class FileStorageKeys: + FILE_STORAGE_PROVIDER = "FILE_STORAGE_PROVIDER" + FILE_STORAGE_CREDENTIALS = "FILE_STORAGE_CREDENTIALS" + PERMANENT_REMOTE_STORAGE = "PERMANENT_REMOTE_STORAGE" + TEMPORARY_REMOTE_STORAGE = "TEMPORARY_REMOTE_STORAGE" + + +class FileStorageType(Enum): + PERMANENT = "permanent" + TEMPORARY = "temporary" + + +class ExecutionSource(Enum): + IDE = "ide" + TOOL = "tool" diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index 62a24f2d1..126a41c93 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -7,7 +7,12 @@ from dotenv import load_dotenv from flask import Flask, current_app from unstract.prompt_service.config import db -from unstract.prompt_service.constants import DBTableV2 +from unstract.prompt_service.constants import ( + DBTableV2, + ExecutionSource, + FeatureFlag, + FileStorageKeys, +) from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.db_utils import DBUtils from unstract.prompt_service.env_manager import EnvLoader @@ -16,6 +21,13 @@ from unstract.sdk.exceptions import SdkError from unstract.sdk.llm import LLM +from unstract.flags.src.unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + from unstract.sdk.file_storage import FileStorage, FileStorageProvider + from unstract.sdk.file_storage.constants import StorageType + from unstract.sdk.file_storage.env_helper import EnvHelper + load_dotenv() # Global variable to store plugins @@ -278,6 +290,7 @@ def run_completion( prompt_type: Optional[str] = PSKeys.TEXT, enable_highlight: bool = False, file_path: str = "", + execution_source: Optional[str] = None, ) -> str: logger: Logger = current_app.logger try: @@ -286,9 +299,27 @@ def run_completion( ) highlight_data = None if highlight_data_plugin and enable_highlight: - highlight_data = highlight_data_plugin["entrypoint_cls"]( - logger=current_app.logger, file_path=file_path - ).run + if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + fs_instance: FileStorage = FileStorage(FileStorageProvider.LOCAL) + if execution_source == ExecutionSource.IDE.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + if execution_source == ExecutionSource.TOOL.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.TEMPORARY, + env_name=FileStorageKeys.TEMPORARY_REMOTE_STORAGE, + ) + highlight_data = highlight_data_plugin["entrypoint_cls"]( + logger=current_app.logger, + file_path=file_path, + fs_instance=fs_instance, + ).run + else: + highlight_data = highlight_data_plugin["entrypoint_cls"]( + logger=current_app.logger, file_path=file_path + ).run completion = llm.complete( prompt=prompt, process_text=highlight_data, @@ -325,6 +356,7 @@ def extract_table( structured_output: dict[str, Any], llm: LLM, enforce_type: str, + execution_source: str, ) -> dict[str, Any]: table_settings = output[PSKeys.TABLE_SETTINGS] table_extractor: dict[str, Any] = plugins.get("table-extractor", {}) @@ -333,10 +365,32 @@ def extract_table( "Unable to extract table details. " "Please contact admin to resolve this issue." ) + if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + fs_instance: FileStorage = FileStorage(FileStorageProvider.LOCAL) + if execution_source == ExecutionSource.IDE.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + if execution_source == ExecutionSource.TOOL.value: + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.TEMPORARY, + env_name=FileStorageKeys.TEMPORARY_REMOTE_STORAGE, + ) try: - answer = table_extractor["entrypoint_cls"].extract_large_table( - llm=llm, table_settings=table_settings, enforce_type=enforce_type - ) + if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + answer = table_extractor["entrypoint_cls"].extract_large_table( + llm=llm, + table_settings=table_settings, + enforce_type=enforce_type, + fs_instance=fs_instance, + ) + else: + answer = table_extractor["entrypoint_cls"].extract_large_table( + llm=llm, + table_settings=table_settings, + enforce_type=enforce_type, + ) structured_output[output[PSKeys.NAME]] = answer # We do not support summary and eval for table. # Hence returning the result diff --git a/prompt-service/src/unstract/prompt_service/main.py b/prompt-service/src/unstract/prompt_service/main.py index 91f0b5d2c..7dbe6dca2 100644 --- a/prompt-service/src/unstract/prompt_service/main.py +++ b/prompt-service/src/unstract/prompt_service/main.py @@ -111,6 +111,8 @@ def prompt_processor() -> Any: } metrics: dict = {} variable_names: list[str] = [] + # Identifier for source of invocation + execution_source = payload.get(PSKeys.EXECUTION_SOURCE, "") publish_log( log_events_id, {"tool_id": tool_id, "run_id": run_id, "doc_name": doc_name}, @@ -226,6 +228,7 @@ def prompt_processor() -> Any: structured_output=structured_output, llm=llm, enforce_type=output[PSKeys.TYPE], + execution_source=execution_source, ) metadata = query_usage_metadata(token=platform_key, metadata=metadata) response = { diff --git a/tools/structure/src/constants.py b/tools/structure/src/constants.py index 4ea7e6b7c..8f77d9ed3 100644 --- a/tools/structure/src/constants.py +++ b/tools/structure/src/constants.py @@ -75,5 +75,7 @@ class SettingsKeys: CONFIDENCE_DATA = "confidence_data" EXECUTION_RUN_DATA_FOLDER = "EXECUTION_RUN_DATA_FOLDER" FILE_PATH = "file_path" + EXECUTION_SOURCE = "execution_source" + TOOL = "tool" METRICS = "metrics" INDEXING = "indexing" diff --git a/tools/structure/src/main.py b/tools/structure/src/main.py index c55b76a11..6d1b16287 100644 --- a/tools/structure/src/main.py +++ b/tools/structure/src/main.py @@ -120,6 +120,7 @@ def run( SettingsKeys.FILE_HASH: file_hash, SettingsKeys.FILE_NAME: file_name, SettingsKeys.FILE_PATH: extracted_input_file, + SettingsKeys.EXECUTION_SOURCE: SettingsKeys.TOOL, } # TODO: Need to split extraction and indexing # to avoid unwanted indexing