fix: DIA-1489: Multiskill prompt autorefinement fixes (#237)
Co-authored-by: nik <[email protected]>
Co-authored-by: niklub <[email protected]>
3 people authored Oct 29, 2024
1 parent 0c667eb commit c34fd15
Showing 6 changed files with 77 additions and 152 deletions.
10 changes: 8 additions & 2 deletions adala/agents/base.py
@@ -402,6 +402,8 @@ async def arefine_skill(
skill_name: str,
input_variables: List[str],
data: Optional[List[Dict]] = None,
reapply: bool = False,
instructions: Optional[str] = None,
) -> ImprovedPromptResponse:
"""
beta v2 of Agent.learn() that is:
@@ -429,12 +431,16 @@ async def arefine_skill(
predictions = None
else:
inputs = InternalDataFrame.from_records(data or [])
predictions = await self.skills.aapply(inputs, runtime=runtime)

if reapply:
predictions = await self.skills.aapply(inputs, runtime=runtime)
else:
predictions = inputs

response = await skill.aimprove(
predictions=predictions,
teacher_runtime=teacher_runtime,
target_input_variables=input_variables,
instructions=instructions,
)
return response

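The new `reapply` flag controls what `arefine_skill` hands to `skill.aimprove()`: with `reapply=True` the skill is re-run on the supplied records, while with `reapply=False` the records are passed through as if they were predictions. The following is a minimal, self-contained sketch of that branch, assuming plain pandas in place of adala's `InternalDataFrame` and a made-up `fake_apply()` coroutine in place of `skills.aapply()`; only the branching logic mirrors the diff.

```python
# Minimal sketch of the new reapply branch; everything except the branching
# logic is an illustrative assumption, not adala's real implementation.
import asyncio
from typing import Dict, List, Optional

import pandas as pd


async def fake_apply(inputs: pd.DataFrame) -> pd.DataFrame:
    # Stand-in for the skill run: pretend the skill adds a prediction column.
    out = inputs.copy()
    out["label"] = "predicted"
    return out


async def refine(data: Optional[List[Dict]], reapply: bool) -> pd.DataFrame:
    inputs = pd.DataFrame.from_records(data or [])
    if reapply:
        # Re-run the skill so the improver sees fresh predictions.
        predictions = await fake_apply(inputs)
    else:
        # New pass-through path in this commit: use the provided rows as predictions.
        predictions = inputs
    return predictions


print(asyncio.run(refine([{"text": "example"}], reapply=True)))
print(asyncio.run(refine([{"text": "example", "label": "neg"}], reapply=False)))
```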
23 changes: 14 additions & 9 deletions adala/skills/_base.py
@@ -484,6 +484,7 @@ async def aimprove(
teacher_runtime: AsyncRuntime,
target_input_variables: List[str],
predictions: Optional[InternalDataFrame] = None,
instructions: Optional[str] = None,
):
"""
Improves the skill.
@@ -501,6 +502,7 @@
prompt_improvement_skill = PromptImprovementSkill(
skill_to_improve=self,
input_variables=target_input_variables,
instructions=instructions,
)
if predictions is None:
input_df = InternalDataFrame()
@@ -613,6 +615,11 @@ class AnalysisSkill(Skill):
def _iter_over_chunks(
self, input: InternalDataFrame, chunk_size: Optional[int] = None
):
"""
Iterates over chunks of the input dataframe.
Returns a generator of strings that are the concatenation of the rows of the chunk with `input_separator`
interpolated with the `input_template` and `extra_fields`.
"""

if input.empty:
yield ""
@@ -624,7 +631,7 @@ def _iter_over_chunks(
input = InternalDataFrame([input])

extra_fields = self._get_extra_fields()

# if chunk_size is specified, split the input into chunks and process each chunk separately
if self.chunk_size is not None:
chunks = (
@@ -633,21 +640,19 @@
)
else:
chunks = [input]

# define the row preprocessing function
def row_preprocessing(row):
return partial_str_format(self.input_template, **row, **extra_fields, i=int(row.name) + 1)

total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
# interpolate every row with input_template and concatenate them with input_separator to produce a single string
agg_chunk = (
chunk.reset_index()
.apply(
lambda row: partial_str_format(
self.input_template,
**row, **extra_fields, i=int(row.name) + 1
),
axis=1,
)
.apply(row_preprocessing, axis=1)
.str.cat(sep=self.input_separator)
)

yield agg_chunk

def apply(
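For reference, the refactored `_iter_over_chunks` routes each row through a named `row_preprocessing` helper and joins the interpolated rows with `input_separator`, one string per chunk. The sketch below approximates that aggregation with plain pandas; the simplified `partial_str_format` stand-in, the function signature, and the example fields are assumptions for illustration, not a drop-in copy of the real method.

```python
# Approximate, self-contained sketch of the chunk aggregation performed above.
from typing import Iterator, Optional

import pandas as pd


def partial_str_format(template: str, **kwargs) -> str:
    # Stand-in for adala.utils.parse.partial_str_format; the real helper
    # tolerates missing keys, str.format here does not.
    return template.format(**kwargs)


def iter_over_chunks(
    df: pd.DataFrame,
    input_template: str,
    input_separator: str = "\n",
    chunk_size: Optional[int] = None,
) -> Iterator[str]:
    if df.empty:
        yield ""
        return

    if chunk_size is not None:
        chunks = (df.iloc[i : i + chunk_size] for i in range(0, len(df), chunk_size))
    else:
        chunks = [df]

    def row_preprocessing(row):
        # Interpolate each row into the template; `i` is the 1-based row index.
        return partial_str_format(input_template, **row, i=int(row.name) + 1)

    for chunk in chunks:
        # One aggregated string per chunk, rows joined with the separator.
        yield (
            chunk.reset_index(drop=True)
            .apply(row_preprocessing, axis=1)
            .str.cat(sep=input_separator)
        )


df = pd.DataFrame({"text": ["a", "b", "c"]})
for block in iter_over_chunks(df, "{i}. {text}", chunk_size=2):
    print(block, end="\n---\n")
```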
177 changes: 41 additions & 136 deletions adala/skills/collection/prompt_improvement.py
@@ -1,8 +1,9 @@
import json
import logging
from pydantic import BaseModel, field_validator, Field, ConfigDict, model_validator
from pydantic import BaseModel, field_validator, Field, ConfigDict, model_validator, AfterValidator
from adala.skills import Skill
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Annotated
from adala.skills import AnalysisSkill
from adala.utils.parse import parse_template
from adala.utils.types import ErrorResponseModel
@@ -11,35 +12,20 @@
logger = logging.getLogger(__name__)


def validate_used_variables(value: str) -> str:
templates = parse_template(value, include_texts=False)
if not templates:
raise ValueError("At least one input variable must be used in the prompt, formatted with curly braces like this: {input_variable}")
return value


class PromptImprovementSkillResponseModel(BaseModel):

reasoning: str = Field(
..., description="The reasoning for the changes made to the prompt"
)
new_prompt_title: str = Field(..., description="The new short title for the prompt")
new_prompt_content: str = Field(..., description="The new content for the prompt")

# model_config = ConfigDict(
# # omit other fields
# extra="ignore",
# # guard against name collisions with other fields
# populate_by_name=False,
# )

# @field_validator("new_prompt_content", mode="after")
# def validate_used_variables(cls, value: str) -> str:

# templates = parse_template(value, include_texts=False)
# if not templates:
# raise ValueError("At least one input variable must be used in the prompt")

# input_vars_used = [t["text"] for t in templates]
# if extra_vars_used := set(input_vars_used) - set(cls._input_variables):
# raise ValueError(
# f"Invalid variable used in prompt: {extra_vars_used}. Valid variables are: {cls._input_variables}"
# )

# return value
new_prompt_content: Annotated[str, AfterValidator(validate_used_variables)]


class ImprovedPromptResponse(BaseModel):
@@ -65,139 +51,58 @@ class PromptImprovementSkill(AnalysisSkill):
input_variables: List[str]

name: str = "prompt_improvement"
instructions: str = "" # Automatically generated
input_template: str = "" # Not used
input_prefix: str = (
"Here are a few prediction results after applying the current prompt for your analysis.\n\n"
)
input_separator: str = "\n\n"
instructions: str = "Improve current prompt"
input_template: str = "" # Used to provide a few shot examples of input-output pairs
input_prefix: str = "" # Used to provide additional context for the input
input_separator: str = "\n"

response_model = PromptImprovementSkillResponseModel

@model_validator(mode="after")
def validate_prompts(self):

input_variables = "\n".join(self.input_variables)
def get_json_template(fields):
json_body = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
return "{" + json_body + "}"

if isinstance(self.skill_to_improve, LabelStudioSkill):
model_json_schema = self.skill_to_improve.field_schema
else:
model_json_schema = self.skill_to_improve.response_model.model_json_schema()
# rewrite the instructions with the actual values
self.instructions = f"""\
You are a prompt engineer tasked with generating or enhancing a prompt for a Language Learning Model (LLM). Your goal is to create an effective prompt based on the given context, input data and requirements.
First, carefully review the following context information:
# Given context

## Task name
{self.skill_to_improve.name}
## Task description
{self.skill_to_improve.description}
## Input variables to use
{input_variables}
## Target response schema
```json
{json.dumps(model_json_schema, indent=2)}
input_variables = self.input_variables
output_variables = list(model_json_schema['properties'].keys())
input_json_template = get_json_template(input_variables)
output_json_template = get_json_template(output_variables)
self.input_template = f'{input_json_template} --> {output_json_template}'

self.input_prefix = f'''
## Current prompt:
```
Now, examine the current prompt (if provided):
# Current prompt
{self.skill_to_improve.input_template}
If a current prompt is provided, analyze it for potential improvements or errors. Consider how well it addresses the task description, input data and if it effectively utilizes all provided input variables.
Before creating the new prompt, provide a detailed reasoning for your choices. Using only bullet points, list the changes to be made and concisely explain why each change is necessary. Include:
1. How you plan to address the context and task description
2. Specific errors or improvements you've identified in the previous prompt (if applicable)
3. How you intend to tailor the new prompt to better suit the target model provider
4. Your strategy for designing the prompt to generate responses matching the provided schema
Next, generate a new short prompt title that accurately reflects the task and purpose of the prompt.
Finally, create the new prompt content. Ensure that you:
1. Incorporate all provided input variables, formatted with "{{" and "}}" brackets
2. Address the specific task description provided in the context
3. Consider the target model provider's capabilities and limitations
4. Maintain or improve upon any relevant information from the current prompt (if provided)
5. Structure the prompt to elicit a response that matches the provided response schema
Present your output in JSON format including the following fields:
- reasoning
- new_prompt_title
- new_prompt_content
# Example of the expected input and output:
Input context:
## Target model provider
OpenAI
## Task description
Generate a summary of the input text.
## Allowed input variables
text
document_metadata
## Target response schema
```json
{{
"summary": {{
"type": "string"
}},
"categories": {{
"type": "string",
"enum": ["news", "science", "politics", "sports", "entertainment"]
}}
}}
```
Check the following example to see how the model should respond:
## Current Labeling Config:
```xml
{self.skill_to_improve.label_config}
```
Current prompt:
## Input variables:
```
Generate a summary of the input text: "{{text}}".
{input_variables}
```
# Current prompt output
Generate a summary of the input text: "The quick brown fox jumps over the lazy dog." --> {{"summary": "The quick brown fox jumps over the lazy dog.", "categories": "news"}}
Generate a summary of the input text: "When was the Battle of Hastings?" --> {{"summary": "The Battle of Hastings was a decisive Norman victory in 1066, marking the end of Anglo-Saxon rule in England.", "categories": "history"}}
Generate a summary of the input text: "What is the capital of France?" --> {{ "summary": "The capital of France is Paris.", "categories": "geography"}}
Your output:
## Model response schema:
```json
{{
"reasoning": "Changes needed:
• Specify format and style of summary: The current prompt is vague, leading to inconsistent outputs.
• Add categories instructions: The prompt doesn't mention categorization, resulting in missing or incorrect categories.
• Use all input variables: The 'document_metadata' variable is not utilized in the current prompt.
• Align with response schema: Ensure only allowed categories are used (e.g., remove 'history').
• Improve clarity: Guide the model to summarize the text rather than answering questions.
These changes will ensure higher quality, more consistent responses that meet the specified requirements.",
"new_prompt_title": "Including categories instructions in the summary",
"new_prompt_content": "Generate a detailed summary of the input text:\n'''{{text}}'''.\nUse the document metadata to guide the model to produce categories.\n#Metadata:\n'''{{document_metadata}}'''.\nEnsure high quality output by asking the model to produce a detailed summary and to categorize the document."
}}
{json.dumps(model_json_schema, indent=2)}
```
Ensure that your refined prompt is clear, concise, and effectively guides the LLM to produce high quality responses.
"""
## Input-Output Examples:
# Create the output template for JSON output based on the response model fields
fields = self.skill_to_improve.response_model.model_fields
field_template = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
self.output_template = "{{" + field_template + "}}"
self.input_template = (
f"{self.skill_to_improve.input_template} --> {self.output_template}"
)
'''

# TODO: deprecated, leave self.output_template for compatibility
self.output_template = output_json_template

logger.debug(f'Instructions: {self.instructions}\nInput template: {self.input_template}\nInput prefix: {self.input_prefix}')
return self
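Two patterns introduced in this file can be exercised in isolation: the module-level `validate_used_variables` attached via pydantic's `Annotated[str, AfterValidator(...)]` (replacing the commented-out `field_validator`), and the `get_json_template` helper used to build the `{input} --> {output}` few-shot template. The standalone sketch below mirrors both; the regex stand-in for `parse_template` and the example model are assumptions made for illustration.

```python
# Standalone sketch of the validator and JSON-template patterns from this diff.
import re
from typing import List

from pydantic import AfterValidator, BaseModel, Field
from typing_extensions import Annotated


def validate_used_variables(value: str) -> str:
    # Regex stand-in for adala.utils.parse.parse_template: find {placeholder} uses.
    if not re.findall(r"\{(\w+)\}", value):
        raise ValueError(
            "At least one input variable must be used in the prompt, "
            "formatted with curly braces like this: {input_variable}"
        )
    return value


class PromptResponse(BaseModel):
    reasoning: str = Field(..., description="Why the prompt was changed")
    new_prompt_title: str
    new_prompt_content: Annotated[str, AfterValidator(validate_used_variables)]


def get_json_template(fields: List[str]) -> str:
    # Builds a JSON-shaped template like {"text": "{text}"} from field names.
    json_body = ", ".join(f'"{field}": "{{{field}}}"' for field in fields)
    return "{" + json_body + "}"


# The validator accepts a prompt that uses at least one placeholder.
PromptResponse(
    reasoning="added the missing input variable",
    new_prompt_title="Summarize",
    new_prompt_content="Summarize: {text}",
)

# Few-shot template joining input and output JSON shapes, as built in validate_prompts().
print(get_json_template(["text"]) + " --> " + get_json_template(["summary", "categories"]))
```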
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
kombu = ">=5.4.0rc2" # Pin version to fix https://github.com/celery/celery/issues/8030. TODO: remove when this fix will be included in celery
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/238b0f48dcac00d78a6e97862ea82011527307ce.zip"}
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/a588ba578c7ee12e80b244dac6bd09331d0b95eb.zip"}
kafka-python = "^2.0.2"
# https://github.com/geerlingguy/ansible-role-docker/issues/462#issuecomment-2144121102
requests = "2.31.0"
11 changes: 10 additions & 1 deletion server/app.py
@@ -378,6 +378,14 @@ class ImprovedPromptRequest(BaseModel):
default=None,
description="Batch of data to run the skill on",
)
reapply: bool = Field(
default=False,
description="Whether to reapply the skill to the data before improving the prompt",
)
instructions: Optional[str] = Field(
default='Improve current prompt',
description="Instructions for the prompt improvement task",
)

@field_validator("agent", mode="after")
def validate_teacher_runtime(cls, agent: Agent) -> Agent:
@@ -404,11 +412,12 @@ async def improved_prompt(request: ImprovedPromptRequest):
Returns:
Response: Response model for prompt improvement skill
"""

improved_prompt_response = await request.agent.arefine_skill(
skill_name=request.skill_to_improve,
input_variables=request.input_variables,
data=request.data,
reapply=request.reapply,
instructions=request.instructions,
)

return Response[ImprovedPromptResponse](
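For context on the two new request fields, here is a minimal stand-in for the extended request model showing how `reapply` and `instructions` travel in the request body and what they default to when omitted; the class name, example values, and the omitted `agent` field are placeholders, and the real `ImprovedPromptRequest` in `server/app.py` carries additional validation.

```python
# Minimal stand-in for the extended request model; only the `reapply` and
# `instructions` fields mirror this diff, the rest is illustrative.
from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class ImprovedPromptRequestSketch(BaseModel):
    skill_to_improve: str
    input_variables: List[str]
    data: Optional[List[Dict]] = Field(
        default=None, description="Batch of data to run the skill on"
    )
    reapply: bool = Field(
        default=False,
        description="Whether to reapply the skill to the data before improving the prompt",
    )
    instructions: Optional[str] = Field(
        default="Improve current prompt",
        description="Instructions for the prompt improvement task",
    )


req = ImprovedPromptRequestSketch(
    skill_to_improve="my_skill",
    input_variables=["text"],
    data=[{"text": "example input"}],
    reapply=True,
    instructions="Make the prompt mention every input variable",
)
print(req.model_dump())
```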
