From 5501972e2daec3efd5153125a13e44fc5893476b Mon Sep 17 00:00:00 2001 From: jagadeeswaran-zipstack Date: Thu, 19 Dec 2024 17:59:48 +0530 Subject: [PATCH] [Fix] Prompt studio Coverage (#907) * fixes for prompt sudio coverage * fixed prompt studio local variable issue * updated return type * added optional chaining * added missing type * added comment for error suppression * handled edge cases --- .../prompt_studio_core_v2/serializers.py | 63 +++++++++++++------ .../output_manager_util.py | 35 ++++------- .../document-parser/DocumentParser.jsx | 25 +------- .../manage-docs-modal/ManageDocsModal.jsx | 43 ++++++++++--- .../prompt-card/PromptCardItems.jsx | 22 +++++-- frontend/src/hooks/usePromptOutput.js | 5 +- 6 files changed, 110 insertions(+), 83 deletions(-) diff --git a/backend/prompt_studio/prompt_studio_core_v2/serializers.py b/backend/prompt_studio/prompt_studio_core_v2/serializers.py index bd9e32262..d6e79483a 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/serializers.py +++ b/backend/prompt_studio/prompt_studio_core_v2/serializers.py @@ -44,48 +44,71 @@ class Meta: def to_representation(self, instance): # type: ignore data = super().to_representation(instance) + default_profile = None + + # Fetch summarize LLM profile try: - profile_manager = ProfileManager.objects.get( + summarize_profile = ProfileManager.objects.get( prompt_studio_tool=instance, is_summarize_llm=True ) - data[TSKeys.SUMMARIZE_LLM_PROFILE] = profile_manager.profile_id + data[TSKeys.SUMMARIZE_LLM_PROFILE] = summarize_profile.profile_id except ObjectDoesNotExist: logger.info( - "Summarize LLM profile doesnt exist for prompt tool %s", + "Summarize LLM profile doesn't exist for prompt tool %s", str(instance.tool_id), ) + + # Fetch default LLM profile try: - profile_manager = ProfileManager.get_default_llm_profile(instance) - data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id + default_profile = ProfileManager.get_default_llm_profile(instance) + data[TSKeys.DEFAULT_PROFILE] = default_profile.profile_id except DefaultProfileError: + # To make it compatible with older projects error suppressed with warning. logger.warning( - "Default LLM profile doesnt exist for prompt tool %s", + "Default LLM profile doesn't exist for prompt tool %s", str(instance.tool_id), ) - prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter( + + # Fetch prompt instances + prompt_instances: ToolStudioPrompt = ToolStudioPrompt.objects.filter( tool_id=data.get(TSKeys.TOOL_ID) ).order_by("sequence_number") - data[TSKeys.PROMPTS] = [] + + if not prompt_instances.exists(): + data[TSKeys.PROMPTS] = [] + return data + + # Process prompt instances output: list[Any] = [] - # Appending prompt instances of the tool for FE Processing - if prompt_instance.count() != 0: - for prompt in prompt_instance: - profile_manager_id = prompt.prompt_id - if instance.single_pass_extraction_mode: - # use projects default profile - profile_manager_id = profile_manager.profile_id - prompt_serializer = ToolStudioPromptSerializer(prompt) + for prompt in prompt_instances: + prompt_serializer = ToolStudioPromptSerializer(prompt) + serialized_data = prompt_serializer.data + + # Determine coverage + coverage: list[Any] = [] + profile_manager_id = prompt.profile_manager + if default_profile and instance.single_pass_extraction_mode: + profile_manager_id = default_profile.profile_id + + if profile_manager_id: coverage = OutputManagerUtils.get_coverage( data.get(TSKeys.TOOL_ID), profile_manager_id, prompt.prompt_id, instance.single_pass_extraction_mode, ) - serialized_data = prompt_serializer.data - serialized_data["coverage"] = coverage - output.append(serialized_data) - data[TSKeys.PROMPTS] = output + else: + logger.info( + "Skipping coverage calculation for prompt %s " + "due to missing profile ID", + str(prompt.prompt_key), + ) + + # Add coverage to serialized data + serialized_data["coverage"] = coverage + output.append(serialized_data) + data[TSKeys.PROMPTS] = output data["created_by_email"] = instance.created_by.email return data diff --git a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_util.py b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_util.py index 4a0352099..b5b6a957d 100644 --- a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_util.py +++ b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_util.py @@ -1,4 +1,3 @@ -from django.db.models import Count from prompt_studio.prompt_studio_output_manager_v2.models import ( PromptStudioOutputManager, ) @@ -11,41 +10,33 @@ def get_coverage( profile_manager_id: str, prompt_id: str = None, is_single_pass: bool = False, - ) -> dict[str, int]: + ) -> list[str]: """ Method to fetch coverage data for given tool and profile manager. Args: - tool (CustomTool): The tool instance or ID for which coverage is fetched. + tool_id (str): The ID of the tool for which coverage is fetched. profile_manager_id (str): The ID of the profile manager for which coverage is calculated. prompt_id (Optional[str]): The ID of the prompt (optional). - is_single_pass (Optional[bool]): Singlepass enabled or not + is_single_pass (Optional[bool]): Singlepass enabled or not. If provided, coverage is fetched for the specific prompt. Returns: - dict[str, int]: A dictionary containing coverage information. + dict[str, list[str]]: A dictionary containing coverage information. Keys are formatted as "coverage__". - Values are the count of documents associated with each prompt + Values are lists of document IDs associated with each prompt and profile combination. """ # TODO: remove singlepass reference - prompt_outputs = ( - PromptStudioOutputManager.objects.filter( - tool_id=tool_id, - profile_manager_id=profile_manager_id, - prompt_id=prompt_id, - is_single_pass_extract=is_single_pass, - ) - .values("prompt_id", "profile_manager_id") - .annotate(document_count=Count("document_manager_id")) - ) + prompt_outputs = PromptStudioOutputManager.objects.filter( + tool_id=tool_id, + profile_manager_id=profile_manager_id, + prompt_id=prompt_id, + is_single_pass_extract=is_single_pass, + ).values("prompt_id", "profile_manager_id", "document_manager_id") - coverage = {} + coverage = [] for prompt_output in prompt_outputs: - prompt_key = str(prompt_output["prompt_id"]) - profile_key = str(prompt_output["profile_manager_id"]) - coverage[f"coverage_{prompt_key}_{profile_key}"] = prompt_output[ - "document_count" - ] + coverage.append(str(prompt_output["document_manager_id"])) return coverage diff --git a/frontend/src/components/custom-tools/document-parser/DocumentParser.jsx b/frontend/src/components/custom-tools/document-parser/DocumentParser.jsx index b075c7e4e..43f64d111 100644 --- a/frontend/src/components/custom-tools/document-parser/DocumentParser.jsx +++ b/frontend/src/components/custom-tools/document-parser/DocumentParser.jsx @@ -180,29 +180,6 @@ function DocumentParser({ return outputs; }; - const getPromptCoverageCount = (promptId) => { - const keys = Object.keys(promptOutputs || {}); - const coverageKey = `coverage_${promptId}`; - const outputs = {}; - if (!keys?.length) { - details?.prompts?.forEach((prompt) => { - if (prompt?.coverage) { - const key = Object.keys(prompt?.coverage)[0]; - if (key?.startsWith(coverageKey)) { - outputs[key] = prompt?.coverage[key]; - } - } - }); - return outputs; - } - keys?.forEach((key) => { - if (key?.startsWith(coverageKey)) { - outputs[key] = promptOutputs[key]; - } - }); - return outputs; - }; - if (!details?.prompts?.length) { if (isSimplePromptStudio && SpsPromptsEmptyState) { return ; @@ -230,7 +207,7 @@ function DocumentParser({ outputs={getPromptOutputs(item?.prompt_id)} enforceTypeList={enforceTypeList} setUpdatedPromptsCopy={setUpdatedPromptsCopy} - coverageCountData={getPromptCoverageCount(item?.prompt_id)} + coverageCountData={item?.coverage} isChallenge={isChallenge} />
diff --git a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx index 3d9d53e7d..d563e8a1e 100644 --- a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx +++ b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx @@ -33,6 +33,7 @@ import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper"; import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader"; import "./ManageDocsModal.css"; import usePostHogEvents from "../../../hooks/usePostHogEvents"; +import { usePromptOutputStore } from "../../../store/prompt-output-store"; let SummarizeStatusTitle = null; try { @@ -90,6 +91,7 @@ function ManageDocsModal({ const axiosPrivate = useAxiosPrivate(); const handleException = useExceptionHandler(); const { setPostHogCustomEvent } = usePostHogEvents(); + const { promptOutputs, updatePromptOutput } = usePromptOutputStore(); const successIndex = ( @@ -543,21 +545,48 @@ function ManageDocsModal({ ); updateCustomTool({ listOfDocs: newListOfDocs }); - if (newListOfDocs?.length === 1 && selectedDoc?.document_id !== docId) { - const doc = newListOfDocs[1]; + if (selectedDoc?.document_id === docId) { + const doc = newListOfDocs[0]; handleDocChange(doc); } - - if (docId === selectedDoc?.document_id) { - updateCustomTool({ selectedDoc: "" }); - handleUpdateTool({ output: "" }); - } + const updatedPromptDetails = removeIdFromCoverage(details, docId); + const updatedPromptOutput = removeIdFromCoverageOfPromptOutput( + promptOutputs, + docId + ); + updateCustomTool({ details: updatedPromptDetails }); + updatePromptOutput(updatedPromptOutput); }) .catch((err) => { setAlertDetails(handleException(err, "Failed to delete")); }); }; + const removeIdFromCoverage = (data, idToRemove) => { + if (data.prompts && Array.isArray(data.prompts)) { + data.prompts.forEach((prompt) => { + if (Array.isArray(prompt.coverage)) { + prompt.coverage = prompt.coverage.filter((id) => id !== idToRemove); + } + }); + } + return data; // Return the updated data + }; + + const removeIdFromCoverageOfPromptOutput = (data, idToRemove) => { + return Object.entries(data).reduce((updatedData, [key, value]) => { + // Create a new object for the current entry + updatedData[key] = { + ...value, + // Update the coverage array if it exists + coverage: value?.coverage + ? value?.coverage?.filter((id) => id !== idToRemove) + : value?.coverage, + }; + return updatedData; + }, {}); + }; + return ( 1; const divRef = useRef(null); const [enforceType, setEnforceType] = useState(""); - const profileId = singlePassExtractMode - ? defaultLlmProfile - : selectedLlmProfileId || defaultLlmProfile; - const coverageKey = generateCoverageKey(promptDetails?.prompt_id, profileId); + const promptId = promptDetails?.prompt_id; + const docId = selectedDoc?.document_id; + const promptProfile = promptDetails?.profile_manager || defaultLlmProfile; + const promptOutputKey = generatePromptOutputKey( + promptId, + docId, + promptProfile, + singlePassExtractMode, + true + ); + const promptCoverage = + promptOutputs[promptOutputKey]?.coverage || coverageCountData; useEffect(() => { if (enforceType !== promptDetails?.enforce_type) { @@ -213,7 +223,7 @@ function PromptCardItems({ )} - Coverage: {coverageCountData[coverageKey] || 0} of{" "} + Coverage: {promptCoverage?.length || 0} of{" "} {listOfDocs?.length || 0} docs diff --git a/frontend/src/hooks/usePromptOutput.js b/frontend/src/hooks/usePromptOutput.js index ad32591e3..5517ec968 100644 --- a/frontend/src/hooks/usePromptOutput.js +++ b/frontend/src/hooks/usePromptOutput.js @@ -91,7 +91,6 @@ const usePromptOutput = () => { let isTokenUsageForSinglePassAdded = false; const tokenUsageDetails = {}; - data.forEach((item) => { const promptId = item?.prompt_id; const docId = item?.document_manager; @@ -109,7 +108,6 @@ const usePromptOutput = () => { isSinglePass, true ); - const coverageKey = `coverage_${item?.prompt_id}_${llmProfile}`; outputs[key] = { runId: item?.run_id, promptOutputId: item?.prompt_output_id, @@ -119,8 +117,8 @@ const usePromptOutput = () => { tokenUsage: item?.token_usage, output: item?.output, timer, + coverage: item?.coverage, }; - outputs[coverageKey] = item?.coverage[coverageKey] || 0; if (item?.is_single_pass_extract && isTokenUsageForSinglePassAdded) return; @@ -150,7 +148,6 @@ const usePromptOutput = () => { ); tokenUsageDetails[tokenUsageId] = item?.token_usage; }); - if (isReset) { setPromptOutput(outputs); setTokenUsage(tokenUsageDetails);