Skip to content

Commit

Permalink
Merge branch 'main' into feat/FileEXecutionModelV1
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-ali-e authored Dec 17, 2024
2 parents 65a1f57 + e75d51f commit d7ac1e9
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 26 deletions.
7 changes: 6 additions & 1 deletion backend/prompt_studio/prompt_studio_core_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ def to_representation(self, instance): # type: ignore
# Appending prompt instances of the tool for FE Processing
if prompt_instance.count() != 0:
for prompt in prompt_instance:
profile_manager_id = prompt.prompt_id
if instance.single_pass_extraction_mode:
# use projects default profile
profile_manager_id = profile_manager.profile_id
prompt_serializer = ToolStudioPromptSerializer(prompt)
coverage = OutputManagerUtils.get_coverage(
data.get(TSKeys.TOOL_ID),
prompt.profile_manager_id,
profile_manager_id,
prompt.prompt_id,
instance.single_pass_extraction_mode,
)
serialized_data = prompt_serializer.data
serialized_data["coverage"] = coverage
Expand Down
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
PromptStudioRegistryInfoSerializer,
)
from prompt_studio.prompt_studio_v2.constants import ToolStudioPromptErrors
from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt
from prompt_studio.prompt_studio_v2.serializers import ToolStudioPromptSerializer
from rest_framework import status, viewsets
from rest_framework.decorators import action
Expand Down Expand Up @@ -380,6 +381,13 @@ def create_profile_manager(self, request: HttpRequest, pk: Any = None) -> Respon
raise MaxProfilesReachedError()
try:
self.perform_create(serializer)
# Check if this is the first profile and make it default for all prompts
if profile_count == 0:
profile_manager = serializer.instance # Newly created profile manager
ToolStudioPrompt.objects.filter(tool_id=prompt_studio_tool).update(
profile_manager=profile_manager
)

except IntegrityError:
raise DuplicateData(
f"{ProfileManagerErrors.PROFILE_NAME_EXISTS}, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,17 @@ def fetch_default_llm_profile(tool: CustomTool) -> ProfileManager:

@staticmethod
def fetch_default_output_response(
tool_studio_prompts: list[ToolStudioPrompt], document_manager_id: str
tool_studio_prompts: list[ToolStudioPrompt],
document_manager_id: str,
use_default_profile: bool = False,
) -> dict[str, Any]:
"""Method to frame JSON responses for combined output for default for
default profile manager of the project.
Args:
tool_studio_prompts (list[ToolStudioPrompt])
document_manager_id (str)
use_default_profile (bool)
Returns:
dict[str, Any]: Formatted JSON response for combined output.
Expand All @@ -213,10 +216,16 @@ def fetch_default_output_response(
profile_manager_id = tool_prompt.profile_manager_id

# If profile_manager is not set, skip this record
if not profile_manager_id:
if not profile_manager_id and not use_default_profile:
result[tool_prompt.prompt_key] = ""
continue

if not profile_manager_id:
default_profile = ProfileManager.get_default_llm_profile(
tool_prompt.tool_id
)
profile_manager_id = default_profile.profile_id

try:
queryset = PromptStudioOutputManager.objects.filter(
prompt_id=prompt_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
class OutputManagerUtils:
@staticmethod
def get_coverage(
tool_id: str, profile_manager_id: str, prompt_id: str = None
tool_id: str,
profile_manager_id: str,
prompt_id: str = None,
is_single_pass: bool = False,
) -> dict[str, int]:
"""
Method to fetch coverage data for given tool and profile manager.
Expand All @@ -17,6 +20,7 @@ def get_coverage(
profile_manager_id (str): The ID of the profile manager
for which coverage is calculated.
prompt_id (Optional[str]): The ID of the prompt (optional).
is_single_pass (Optional[bool]): Singlepass enabled or not
If provided, coverage is fetched for the specific prompt.
Returns:
Expand All @@ -25,12 +29,13 @@ def get_coverage(
Values are the count of documents associated with each prompt
and profile combination.
"""

# TODO: remove singlepass reference
prompt_outputs = (
PromptStudioOutputManager.objects.filter(
tool_id=tool_id,
profile_manager_id=profile_manager_id,
**({"prompt_id": prompt_id} if prompt_id else {}),
prompt_id=prompt_id,
is_single_pass_extract=is_single_pass,
)
.values("prompt_id", "profile_manager_id")
.annotate(document_count=Count("document_manager_id"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def to_representation(self, instance):
try:
# Fetch all relevant outputs for the current tool and profile
coverage = OutputManagerUtils.get_coverage(
instance.tool_id, instance.profile_manager_id, instance.prompt_id
instance.tool_id,
instance.profile_manager_id,
instance.prompt_id,
instance.is_single_pass_extract,
)
data["coverage"] = coverage

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def get_output_for_tool_default(self, request: HttpRequest) -> Response:
result: dict[str, Any] = OutputManagerHelper.fetch_default_output_response(
tool_studio_prompts=tool_studio_prompts,
document_manager_id=document_manager_id,
use_default_profile=True,
)

return Response(result, status=status.HTTP_200_OK)
9 changes: 7 additions & 2 deletions backend/workflow_manager/workflow_v2/workflow_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,9 @@ def _set_result_acknowledge(execution: WorkflowExecution) -> None:
f"ExecutionID [{execution.id}] - Task {execution.task_id} acknowledged"
)

@staticmethod
@classmethod
def execute_workflow_async(
cls,
workflow_id: str,
execution_id: str,
hash_values_of_files: dict[str, FileHash],
Expand Down Expand Up @@ -489,7 +490,7 @@ def execute_workflow_async(
}
org_schema = UserContext.get_organization_identifier()
log_events_id = StateStore.get(Common.LOG_EVENTS_ID)
async_execution = WorkflowHelper.execute_bin.apply_async(
async_execution: AsyncResult = cls.execute_bin.apply_async(
args=[
org_schema, # schema_name
workflow_id, # workflow_id
Expand Down Expand Up @@ -524,6 +525,10 @@ def execute_workflow_async(
workflow_execution.status,
result=task_result,
)
# If task is complete, handle acknowledgment and forgetting the
if async_execution.ready():
async_execution.forget() # Remove the result from the result backend.
cls._set_result_acknowledge(workflow_execution)
return execution_response
except celery_exceptions.TimeoutError:
return ExecutionResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,22 @@ function ManageLlmProfiles() {
handleConfirm={() => handleDelete(item?.profile_id)}
content="The LLM profile will be permanently deleted."
>
<Button
size="small"
className="display-flex-align-center"
disabled={isPublicSource}
<Tooltip
title={
defaultLlmProfile === item?.profile_id &&
"Default profile cannot be deleted"
}
>
<DeleteOutlined classID="manage-llm-pro-icon" />
</Button>
<Button
size="small"
className="display-flex-align-center"
disabled={
isPublicSource || defaultLlmProfile === item?.profile_id
}
>
<DeleteOutlined classID="manage-llm-pro-icon" />
</Button>
</Tooltip>
</ConfirmModal>
),
edit: (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ function OutputForDocModal({
llmProfiles,
defaultLlmProfile,
} = useCustomToolStore();
const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile);

const profileId = singlePassExtractMode
? defaultLlmProfile
: profileManagerId;
const [selectedProfile, setSelectedProfile] = useState(profileId);
const { sessionDetails } = useSessionStore();
const axiosPrivate = useAxiosPrivate();
const navigate = useNavigate();
Expand All @@ -79,7 +83,7 @@ function OutputForDocModal({
if (!open) {
return;
}
handleGetOutputForDocs(selectedProfile || profileManagerId);
handleGetOutputForDocs(selectedProfile);
getAdapterInfo();
}, [
open,
Expand All @@ -88,6 +92,10 @@ function OutputForDocModal({
isSinglePassExtractLoading,
]);

useEffect(() => {
setSelectedProfile(profileId);
}, [profileManagerId, singlePassExtractMode]);

useEffect(() => {
handleRowsGeneration(promptOutputs);
}, [promptOutputs, tokenUsage]);
Expand Down Expand Up @@ -284,7 +292,7 @@ function OutputForDocModal({

const handleTabChange = (key) => {
if (key === "0") {
setSelectedProfile(defaultLlmProfile);
setSelectedProfile(profileId);
} else {
setSelectedProfile(adapterData[key - 1]?.profile_id);
}
Expand Down Expand Up @@ -337,10 +345,7 @@ function OutputForDocModal({
></TabPane>
))}
</Tabs>{" "}
<ProfileInfoBar
profileId={selectedProfile || profileManagerId}
profiles={llmProfiles}
/>
<ProfileInfoBar profileId={selectedProfile} profiles={llmProfiles} />
</div>
<div className="display-flex-right">
<Button
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function PromptCardItems({
isPublicSource,
adapters,
defaultLlmProfile,
singlePassExtractMode,
} = useCustomToolStore();
const [isEditingPrompt, setIsEditingPrompt] = useState(false);
const [isEditingTitle, setIsEditingTitle] = useState(false);
Expand All @@ -77,10 +78,10 @@ function PromptCardItems({
const isNotSingleLlmProfile = llmProfiles.length > 1;
const divRef = useRef(null);
const [enforceType, setEnforceType] = useState("");
const coverageKey = generateCoverageKey(
promptDetails?.prompt_id,
selectedLlmProfileId || defaultLlmProfile
);
const profileId = singlePassExtractMode
? defaultLlmProfile
: selectedLlmProfileId || defaultLlmProfile;
const coverageKey = generateCoverageKey(promptDetails?.prompt_id, profileId);

useEffect(() => {
if (enforceType !== promptDetails?.enforce_type) {
Expand Down

0 comments on commit d7ac1e9

Please sign in to comment.