feat: DIA-1511: Support data batch context for autorefinement #232

Merged: 7 commits, Oct 25, 2024
7 changes: 3 additions & 4 deletions adala/agents/base.py
@@ -401,7 +401,7 @@ async def arefine_skill(
self,
skill_name: str,
input_variables: List[str],
batch_data: Optional[BatchData] = None,
data: Optional[List[Dict]] = None,
) -> ImprovedPromptResponse:
"""
beta v2 of Agent.learn() that is:
@@ -412,7 +412,6 @@ async def arefine_skill(
Limitations so far:
- single skill at a time
- only returns the improved input_template, doesn't modify the skill in place
- doesn't use examples/feedback
- no iterations/variable cost
"""

@@ -426,10 +425,10 @@

# get inputs
# TODO: replace it with async environment.get_data_batch()
if batch_data is None:
if data is None:
predictions = None
else:
inputs = InternalDataFrame.from_records(batch_data or [])
inputs = InternalDataFrame.from_records(data or [])
predictions = await self.skills.aapply(inputs, runtime=runtime)

response = await skill.aimprove(
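For orientation, a minimal sketch of how the updated `arefine_skill` signature might be called after this change; the `agent` object and skill name are placeholders, not part of this diff:

```python
import asyncio

# Hypothetical caller of the new signature: the batch is now a plain list of
# dicts passed as `data`, rather than a BatchData wrapper.
async def refine_example(agent):
    response = await agent.arefine_skill(
        skill_name="text_classifier",          # placeholder skill name
        input_variables=["text"],
        data=[{"text": "example record used to ground the refinement"}],
    )
    # response is an ImprovedPromptResponse; on failure, response.output is an ErrorResponseModel
    print(response.output)

# asyncio.run(refine_example(agent))  # `agent` supplied by the caller
```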
44 changes: 30 additions & 14 deletions adala/skills/_base.py
@@ -479,13 +479,23 @@ def improve(
new_prompt = runtime.get_llm_response(messages)
self.instructions = new_prompt


async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables: List[str], predictions: Optional[InternalDataFrame] = None):
async def aimprove(
self,
teacher_runtime: AsyncRuntime,
target_input_variables: List[str],
predictions: Optional[InternalDataFrame] = None,
):
"""
Improves the skill.
"""

from adala.skills.collection.prompt_improvement import PromptImprovementSkill, ImprovedPromptResponse, ErrorResponseModel, PromptImprovementSkillResponseModel
from adala.skills.collection.prompt_improvement import (
PromptImprovementSkill,
ImprovedPromptResponse,
ErrorResponseModel,
PromptImprovementSkillResponseModel,
)

response_dct = {}
try:
prompt_improvement_skill = PromptImprovementSkill(
@@ -500,7 +510,7 @@ async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables:
input=input_df,
runtime=teacher_runtime,
)

# awkward to go from response model -> dict -> df -> dict -> response model
response_dct = response_df.iloc[0].to_dict()

@@ -511,12 +521,14 @@ async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables:
output = PromptImprovementSkillResponseModel(**response_dct)

except Exception as e:
logger.error(f"Error improving skill: {e}. Traceback: {traceback.format_exc()}")
logger.error(
f"Error improving skill: {e}. Traceback: {traceback.format_exc()}"
)
output = ErrorResponseModel(
_adala_message=str(e),
_adala_details=traceback.format_exc(),
)

# get tokens and token cost
resp = ImprovedPromptResponse(output=output, **response_dct)
logger.debug(f"resp: {resp}")
@@ -593,22 +605,24 @@ class AnalysisSkill(Skill):
Analysis skill that analyzes a dataframe and returns a record (e.g. for data analysis purposes).
See base class Skill for more information about the attributes.
"""

input_prefix: str = ""
input_separator: str = "\n"
chunk_size: Optional[int] = None

def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int] = None):
def _iter_over_chunks(
self, input: InternalDataFrame, chunk_size: Optional[int] = None
):

if input.empty:
yield ""
return

if isinstance(input, InternalSeries):
input = input.to_frame()
elif isinstance(input, dict):
input = InternalDataFrame([input])


extra_fields = self._get_extra_fields()

# if chunk_size is specified, split the input into chunks and process each chunk separately
@@ -622,16 +636,18 @@ def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int]

total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
agg_chunk = chunk\
.reset_index()\
agg_chunk = (
chunk.reset_index()
.apply(
lambda row: partial_str_format(
self.input_template,
**row, **extra_fields, i=int(row.name) + 1
),
axis=1,
).str.cat(sep=self.input_separator)

)
.str.cat(sep=self.input_separator)
)

yield agg_chunk

def apply(
@@ -663,7 +679,7 @@ def apply(
output = InternalDataFrame(outputs)

return output

async def aapply(
self,
input: Union[InternalDataFrame, InternalSeries, Dict],
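As a side note on the refactored aggregation above, here is a self-contained sketch of what `_iter_over_chunks` does per chunk; plain `str.format` stands in for `partial_str_format` (which additionally tolerates missing keys), and the template and data are illustrative only:

```python
import pandas as pd

# Each row of a chunk is rendered through the skill's input_template, then all
# rendered rows are joined with input_separator into one string for the LLM.
chunk = pd.DataFrame([{"text": "first example"}, {"text": "second example"}])
input_template = "{i}. {text}"   # assumed template; real skills define their own
input_separator = "\n"

agg_chunk = (
    chunk.reset_index()
    .apply(lambda row: input_template.format(**row, i=int(row.name) + 1), axis=1)
    .str.cat(sep=input_separator)
)
print(agg_chunk)
# 1. first example
# 2. second example
```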
28 changes: 17 additions & 11 deletions adala/skills/collection/prompt_improvement.py
@@ -12,8 +12,9 @@

class PromptImprovementSkillResponseModel(BaseModel):


reasoning: str = Field(..., description="The reasoning for the changes made to the prompt")
reasoning: str = Field(
..., description="The reasoning for the changes made to the prompt"
)
new_prompt_title: str = Field(..., description="The new short title for the prompt")
new_prompt_content: str = Field(..., description="The new content for the prompt")

@@ -49,7 +50,9 @@ class ImprovedPromptResponse(BaseModel):

# these can fail to calculate
prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd", default=None)
completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd", default=None)
completion_cost_usd: Optional[float] = Field(
alias="_completion_cost_usd", default=None
)
total_cost_usd: Optional[float] = Field(alias="_total_cost_usd", default=None)


@@ -61,18 +64,19 @@ class PromptImprovementSkill(AnalysisSkill):
input_variables: List[str]

name: str = "prompt_improvement"
instructions: str = "" # Automatically generated
instructions: str = "" # Automatically generated
input_template: str = "" # Not used
input_prefix: str = "Here are a few prediction results after applying the current prompt for your analysis.\n\n"
input_prefix: str = (
"Here are a few prediction results after applying the current prompt for your analysis.\n\n"
)
input_separator: str = "\n\n"

response_model = PromptImprovementSkillResponseModel


@model_validator(mode="after")
def validate_prompts(self):
input_variables = '\n'.join(self.input_variables)
input_variables = "\n".join(self.input_variables)

# rewrite the instructions with the actual values
self.instructions = f"""\
You are a prompt engineer tasked with generating or enhancing a prompt for a Language Learning Model (LLM). Your goal is to create an effective prompt based on the given context, input data and requirements.
@@ -183,10 +187,12 @@ def validate_prompts(self):
Ensure that your refined prompt is clear, concise, and effectively guides the LLM to produce high quality responses.

"""

# Create the output template for JSON output based on the response model fields
fields = self.skill_to_improve.response_model.model_fields
field_template = ", ".join([f'"{field}": "{{{field}}}"'for field in fields])
field_template = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
self.output_template = "{{" + field_template + "}}"
self.input_template = f"{self.skill_to_improve.input_template} --> {self.output_template}"
self.input_template = (
f"{self.skill_to_improve.input_template} --> {self.output_template}"
)
return self
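To make the template construction at the end of `validate_prompts` concrete, a hedged example with a made-up response model (the field names below are illustrative and not part of this PR):

```python
from pydantic import BaseModel, Field

# Hypothetical response model of the skill being improved.
class SentimentResponse(BaseModel):
    label: str = Field(..., description="Predicted sentiment label")
    reasoning: str = Field(..., description="Why that label was chosen")

fields = SentimentResponse.model_fields
field_template = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
output_template = "{{" + field_template + "}}"
print(output_template)
# {{"label": "{label}", "reasoning": "{reasoning}"}}

# The improvement skill's input_template is then the original skill's
# input_template followed by " --> " and this output template.
```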
7 changes: 4 additions & 3 deletions server/app.py
@@ -374,7 +374,7 @@ class ImprovedPromptRequest(BaseModel):
default=None,
description="List of variables available to use in the input template of the skill, in case any exist that are not currently used",
)
batch_data: Optional[BatchData] = Field(
data: Optional[List[Dict]] = Field(
default=None,
description="Batch of data to run the skill on",
)
@@ -408,14 +408,15 @@ async def improved_prompt(request: ImprovedPromptRequest):
improved_prompt_response = await request.agent.arefine_skill(
skill_name=request.skill_to_improve,
input_variables=request.input_variables,
batch_data=request.batch_data.data if request.batch_data else None
data=request.data,
)

return Response[ImprovedPromptResponse](
success=not isinstance(improved_prompt_response.output, ErrorResponseModel),
data=improved_prompt_response
data=improved_prompt_response,
)


if __name__ == "__main__":
# for debugging
uvicorn.run("app:app", host="0.0.0.0", port=30001)
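For completeness, a hedged sketch of how a client might hit the updated endpoint after this change, assuming the server runs locally on the port above; the agent payload and skill name are placeholders (the tests below use the FastAPI test client instead):

```python
import requests

# The batch now travels as a plain "data" list of dicts instead of a
# "batch_data" object carrying a job_id.
agent_json = {}  # placeholder: a serialized Agent, as built in the tests below
payload = {
    "agent": agent_json,
    "skill_to_improve": "text_classifier",   # placeholder skill name
    "input_variables": ["text"],
    "data": [{"text": "example record"}],
}
resp = requests.post("http://localhost:30001/improved-prompt", json=payload)
resp.raise_for_status()
print(resp.json()["data"]["output"])
```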
22 changes: 11 additions & 11 deletions tests/test_refine_skill.py
@@ -62,7 +62,7 @@ async def test_arefine_skill_no_input_data(client, agent_json):
}

response = client.post("/improved-prompt", json=payload)

assert response.status_code == 200
result = response.json()

@@ -73,7 +73,7 @@ async def test_arefine_skill_no_input_data(client, agent_json):
assert "reasoning" in output
assert "new_prompt_title" in output
assert "new_prompt_content" in output
assert '{text}' in output["new_prompt_content"]
assert "{text}" in output["new_prompt_content"]


@pytest.mark.use_openai
@@ -90,14 +90,11 @@ async def test_arefine_skill_with_input_data(client, agent_json):
"agent": agent_json,
"skill_to_improve": skill_name,
"input_variables": ["text", "id"],
"batch_data": {
'job_id': '123',
'data': batch_data,
}
"data": batch_data,
}

response = client.post("/improved-prompt", json=payload)

assert response.status_code == 200
result = response.json()

@@ -108,8 +105,8 @@ async def test_arefine_skill_with_input_data(client, agent_json):
assert "reasoning" in output
assert "new_prompt_title" in output
assert "new_prompt_content" in output
assert '{text}' in output["new_prompt_content"]
assert '{id}' in output["new_prompt_content"]
assert "{text}" in output["new_prompt_content"]
assert "{id}" in output["new_prompt_content"]


@pytest.mark.use_openai
@@ -140,7 +137,7 @@ def side_effect(*args, **kwargs):
return mock_create.return_value

mock_create.side_effect = side_effect

resp = client.post(
"/improved-prompt",
json={
@@ -152,4 +149,7 @@
assert resp.raise_for_status()
resp_json = resp.json()
assert not resp_json["success"]
assert f"Simulated OpenAI API failure for {skill_name}" == resp_json["data"]["output"]["_adala_details"]
assert (
f"Simulated OpenAI API failure for {skill_name}"
== resp_json["data"]["output"]["_adala_details"]
)