Skip to content

Commit

Permalink
Fix: Coverage count in Prompt studio (#727)
Browse files Browse the repository at this point in the history
* fixed coverage count

* added coverage detail to output

* added coverage count support on BE

* fixed coverage issue on FE

* removed unwanted conditional logic

* code refactor

* code clean up

* changed error to warning

* BE code refactor

* challenger fix

* handled errors properly

* fixed circular import error

* fixed coverage count

---------

Co-authored-by: Jaseem Jas <[email protected]>
  • Loading branch information
jagadeeswaran-zipstack and jaseemjaskp authored Dec 4, 2024
1 parent 2d29154 commit 4109251
Show file tree
Hide file tree
Showing 14 changed files with 193 additions and 27 deletions.
18 changes: 12 additions & 6 deletions backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from prompt_studio.prompt_profile_manager_v2.models import ProfileManager
from prompt_studio.prompt_studio_core_v2.constants import ToolStudioKeys as TSKeys
from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError
from prompt_studio.prompt_studio_output_manager_v2.output_manager_util import (
OutputManagerUtils,
)
from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt
from prompt_studio.prompt_studio_v2.serializers import ToolStudioPromptSerializer
from rest_framework import serializers
Expand Down Expand Up @@ -55,11 +58,10 @@ def to_representation(self, instance): # type: ignore
profile_manager = ProfileManager.get_default_llm_profile(instance)
data[TSKeys.DEFAULT_PROFILE] = profile_manager.profile_id
except DefaultProfileError:
logger.info(
logger.warning(
"Default LLM profile doesnt exist for prompt tool %s",
str(instance.tool_id),
)
try:
prompt_instance: ToolStudioPrompt = ToolStudioPrompt.objects.filter(
tool_id=data.get(TSKeys.TOOL_ID)
).order_by("sequence_number")
Expand All @@ -69,11 +71,15 @@ def to_representation(self, instance): # type: ignore
if prompt_instance.count() != 0:
for prompt in prompt_instance:
prompt_serializer = ToolStudioPromptSerializer(prompt)
output.append(prompt_serializer.data)
coverage = OutputManagerUtils.get_coverage(
data.get(TSKeys.TOOL_ID),
profile_manager.profile_id,
prompt.prompt_id,
)
serialized_data = prompt_serializer.data
serialized_data["coverage"] = coverage
output.append(serialized_data)
data[TSKeys.PROMPTS] = output
except Exception as e:
logger.error(f"Error occured while appending prompts {e}")
return data

data["created_by_email"] = instance.created_by.email

Expand Down
30 changes: 30 additions & 0 deletions backend/prompt_studio/prompt_studio_output_manager/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging

from django.db.models import Count
from usage.helper import UsageHelper

from backend.serializers import AuditSerializer
Expand All @@ -26,6 +27,35 @@ def to_representation(self, instance):
)
token_usage = {}
data["token_usage"] = token_usage
# Get the coverage for the current tool_id and profile_manager_id
try:
# Fetch all relevant outputs for the current tool and profile
related_outputs = (
PromptStudioOutputManager.objects.filter(
tool_id=instance.tool_id,
profile_manager_id=instance.profile_manager_id,
prompt_id=instance.prompt_id,
)
.values("prompt_id", "profile_manager_id")
.annotate(document_count=Count("document_manager_id"))
)

coverage = {}
for output in related_outputs:
prompt_key = str(output["prompt_id"])
profile_key = str(output["profile_manager_id"])
coverage[f"coverage_{profile_key}_{prompt_key}"] = output[
"document_count"
]

data["coverage"] = coverage
except Exception as e:
logger.error(
"Error occurred while fetching "
f"coverage for tool_id {instance.tool_id} "
f"and profile_manager_id {instance.profile_manager_id}: {e}"
)
data["coverage"] = {}
# Convert string to list
try:
data["context"] = json.loads(data["context"])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

from django.db.models import Count
from prompt_studio.prompt_studio_output_manager_v2.models import (
    PromptStudioOutputManager,
)


class OutputManagerUtils:
    """Helpers for querying prompt-output coverage statistics."""

    @staticmethod
    def get_coverage(
        tool_id: str, profile_manager_id: str, prompt_id: Optional[str] = None
    ) -> dict[str, int]:
        """Fetch coverage data for the given tool and profile manager.

        Coverage is the number of document outputs recorded for each
        (prompt, profile) combination.

        Args:
            tool_id (str): ID of the tool for which coverage is fetched.
            profile_manager_id (str): ID of the profile manager for which
                coverage is calculated.
            prompt_id (Optional[str]): Optional prompt ID. When provided,
                coverage is restricted to that specific prompt.

        Returns:
            dict[str, int]: A dictionary containing coverage information.
                Keys are formatted as
                "coverage_<prompt_id>_<profile_manager_id>".
                Values are the count of documents associated with each
                prompt and profile combination.
        """
        filters: dict[str, str] = {
            "tool_id": tool_id,
            "profile_manager_id": profile_manager_id,
        }
        # NOTE: truthiness (not `is not None`) preserves the original
        # behavior — an empty-string prompt_id means "no prompt filter".
        if prompt_id:
            filters["prompt_id"] = prompt_id

        prompt_outputs = (
            PromptStudioOutputManager.objects.filter(**filters)
            .values("prompt_id", "profile_manager_id")
            .annotate(document_count=Count("document_manager_id"))
        )

        # f-string interpolation stringifies the UUID/ID values, matching
        # the explicit str() conversions of the original implementation.
        return {
            f"coverage_{row['prompt_id']}_{row['profile_manager_id']}": row[
                "document_count"
            ]
            for row in prompt_outputs
        }
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from backend.serializers import AuditSerializer

from .models import PromptStudioOutputManager
from .output_manager_util import OutputManagerUtils

logger = logging.getLogger(__name__)

Expand All @@ -16,16 +17,34 @@ class Meta:
fields = "__all__"

def to_representation(self, instance):

data = super().to_representation(instance)
try:
token_usage = UsageHelper.get_aggregated_token_count(instance.run_id)
except Exception as e:
logger.error(
logger.warning(
"Error occured while fetching token usage for run_id"
f"{instance.run_id}: {e}"
" | Process continued"
)
token_usage = {}
data["token_usage"] = token_usage
# Get the coverage for the current tool_id and profile_manager_id
try:
# Fetch all relevant outputs for the current tool and profile
coverage = OutputManagerUtils.get_coverage(
instance.tool_id, instance.profile_manager_id, instance.prompt_id
)
data["coverage"] = coverage

except Exception as e:
logger.error(
"Error occurred while fetching "
f"coverage for tool_id {instance.tool_id} "
f"and profile_manager_id {instance.profile_manager_id}: {e}"
" | Process continued"
)
data["coverage"] = {}
# Convert string to list
try:
data["context"] = json.loads(data["context"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@ function DocumentParser({
}) {
const [enforceTypeList, setEnforceTypeList] = useState([]);
const [updatedPromptsCopy, setUpdatedPromptsCopy] = useState({});
const [isChallenge, setIsChallenge] = useState(false);
const bottomRef = useRef(null);
const { details, isSimplePromptStudio, updateCustomTool, getDropdownItems } =
useCustomToolStore();
const {
details,
isSimplePromptStudio,
updateCustomTool,
getDropdownItems,
isChallengeEnabled,
} = useCustomToolStore();
const { sessionDetails } = useSessionStore();
const { setAlertDetails } = useAlertStore();
const axiosPrivate = useAxiosPrivate();
Expand All @@ -45,6 +51,7 @@ function DocumentParser({
return { value: outputTypeData[item] };
});
setEnforceTypeList(dropdownList1);
setIsChallenge(isChallengeEnabled);

return () => {
// Set the prompts with updated changes when the component is unmounted
Expand All @@ -63,6 +70,10 @@ function DocumentParser({
};
}, []);

useEffect(() => {
setIsChallenge(details.enable_challenge);
}, [details.enable_challenge]);

useEffect(() => {
if (scrollToBottom) {
// Scroll down to the lastest chat.
Expand Down Expand Up @@ -169,6 +180,29 @@ function DocumentParser({
return outputs;
};

const getPromptCoverageCount = (promptId) => {
const keys = Object.keys(promptOutputs || {});
const coverageKey = `coverage_${promptId}`;
const outputs = {};
if (!keys?.length) {
details?.prompts?.forEach((prompt) => {
if (prompt?.coverage) {
const key = Object.keys(prompt?.coverage)[0];
if (key?.startsWith(coverageKey)) {
outputs[key] = prompt?.coverage[key];
}
}
});
return outputs;
}
keys?.forEach((key) => {
if (key?.startsWith(coverageKey)) {
outputs[key] = promptOutputs[key];
}
});
return outputs;
};

if (!details?.prompts?.length) {
if (isSimplePromptStudio && SpsPromptsEmptyState) {
return <SpsPromptsEmptyState />;
Expand Down Expand Up @@ -196,6 +230,8 @@ function DocumentParser({
outputs={getPromptOutputs(item?.prompt_id)}
enforceTypeList={enforceTypeList}
setUpdatedPromptsCopy={setUpdatedPromptsCopy}
coverageCountData={getPromptCoverageCount(item?.prompt_id)}
isChallenge={isChallenge}
/>
<div ref={bottomRef} className="doc-parser-pad-bottom" />
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import {
import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader";
import { useAlertStore } from "../../../store/alert-store";
import { useExceptionHandler } from "../../../hooks/useExceptionHandler";
import { TokenUsage } from "../token-usage/TokenUsage";
import { useTokenUsageStore } from "../../../store/token-usage-store";
import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar";

Expand All @@ -42,7 +41,7 @@ const outputStatus = {
fail: "FAIL",
};

const errorTypes = ["null", "undefined", "false"];
const errorTypes = [null, undefined, false];

function OutputForDocModal({
open,
Expand Down Expand Up @@ -243,6 +242,7 @@ function OutputForDocModal({
status = outputStatus.success;
message = displayPromptResult(output?.output, true);
}
const promptTokenUsage = output?.token_usage?.total_tokens;

if (output?.output === undefined) {
status = outputStatus.yet_to_process;
Expand All @@ -253,17 +253,7 @@ function OutputForDocModal({
const result = {
key: item?.document_id,
document: item?.document_name,
token_count: !singlePassExtractMode && (
<TokenUsage
tokenUsageId={
promptId +
"__" +
item?.document_id +
"__" +
(selectedProfile || profileManagerId)
}
/>
),
token_count: !singlePassExtractMode && (promptTokenUsage || "NA"),
value: (
<>
{isLoading ? (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,16 @@
padding-top: 12px;
position: relative;
background-color: #fff8e6;
max-height: 300px;
overflow-y: scroll;
scrollbar-color: inherit #fff8e6 !important;
display: flex;
flex-grow: 1;
}

.prompt-card-result::-webkit-scrollbar-track {
background-color: #fff8e6 !important;
}
.prompt-card-result .ant-typography {
margin-bottom: 0px;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import usePostHogEvents from "../../../hooks/usePostHogEvents";
import { PromptCardItems } from "./PromptCardItems";
import "./PromptCard.css";
import { handleUpdateStatus } from "./constants";
// import { usePromptRunStatusStore } from "../../../store/prompt-run-status-store";

const PromptCard = memo(
({
Expand All @@ -26,6 +25,8 @@ const PromptCard = memo(
setUpdatedPromptsCopy,
handlePromptRunRequest,
promptRunStatus,
coverageCountData,
isChallenge,
}) => {
const [promptDetailsState, setPromptDetailsState] = useState({});
const [isPromptDetailsStateUpdated, setIsPromptDetailsStateUpdated] =
Expand Down Expand Up @@ -258,6 +259,8 @@ const PromptCard = memo(
handleSpsLoading={handleSpsLoading}
promptOutputs={promptOutputs}
promptRunStatus={promptRunStatus}
coverageCountData={coverageCountData}
isChallenge={isChallenge}
/>
<OutputForDocModal
open={openOutputForDoc}
Expand All @@ -283,6 +286,8 @@ PromptCard.propTypes = {
setUpdatedPromptsCopy: PropTypes.func.isRequired,
handlePromptRunRequest: PropTypes.func.isRequired,
promptRunStatus: PropTypes.object.isRequired,
coverageCountData: PropTypes.object.isRequired,
isChallenge: PropTypes.bool.isRequired,
};

export { PromptCard };
Loading

0 comments on commit 4109251

Please sign in to comment.