Skip to content

Commit

Permalink
fix: DIA-1508: Fix sdk to support Prompts MultiSkill (#233)
Browse files Browse the repository at this point in the history
Co-authored-by: nik <[email protected]>
Co-authored-by: hakan458 <[email protected]>
Co-authored-by: niklub <[email protected]>
Co-authored-by: niklub <[email protected]>
  • Loading branch information
5 people authored Oct 25, 2024
1 parent c31b3b9 commit 43bfb50
Show file tree
Hide file tree
Showing 7 changed files with 844 additions and 50 deletions.
103 changes: 62 additions & 41 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,67 @@
logger = logging.getLogger(__name__)


def extract_indices(
    df,
    input_field_name,
    output_field_name,
    quote_string_field_name='quote_string',
    labels_field_name='label'
):
    """
    Locate each extracted entity's character offsets in the source text.

    Expects ``df[output_field_name]`` to hold lists of entity dicts of the form
    ``[{"quote_string": "entity_1"}, ...]`` and ``df[input_field_name]`` to hold
    the raw text. Each entity dict is enriched in place with ``"start"`` and
    ``"end"`` indices; entities whose quote string cannot be found in the text
    are dropped from the list. The (mutated) dataframe is returned.

    NOTE: matching is naive case-insensitive exact substring search; ambiguity
    handling (surrounding context from the LLM, fuzzy matching against model
    hallucinations) is a possible future improvement.
    """
    for row_idx, row in df.iterrows():
        # Rows that already carry an upstream error are logged and left untouched.
        if row.get("_adala_error"):
            logger.warning(f"Error in row {row_idx}: {row['_adala_message']}")
            continue

        lowered_text = row[input_field_name].lower()
        entities = row[output_field_name]
        not_found = []
        # Maps a matched quote string -> end offset of its most recent match,
        # so repeated/overlapping entities resume searching past earlier hits.
        last_match_end = {}

        for entity in entities:
            quote = entity[quote_string_field_name]
            # Resume after the furthest previous match whose quote string
            # starts with this one, to avoid assigning overlapping spans.
            prior_ends = [
                end for seen, end in last_match_end.items()
                if seen.startswith(quote)
            ]
            search_from = max(prior_ends) if prior_ends else 0

            match_at = lowered_text.find(quote.lower(), search_from)
            if match_at == -1:
                # Quote string absent from the text: schedule for removal.
                not_found.append(entity)
            else:
                span_end = match_at + len(quote)
                entity["start"] = match_at
                entity["end"] = span_end
                last_match_end[quote] = span_end

        for entity in not_found:
            entities.remove(entity)
    return df


def validate_schema(schema: Dict[str, Any]):
expected_schema = {
"type": "object",
Expand Down Expand Up @@ -266,47 +327,7 @@ def extract_indices(self, df):
"""
input_field_name = self._get_input_field_name()
output_field_name = self._get_output_field_name()
for i, row in df.iterrows():
if row.get("_adala_error"):
logger.warning(f"Error in row {i}: {row['_adala_message']}")
continue
text = row[input_field_name]
entities = row[output_field_name]
to_remove = []
found_entities_ends = {}
for entity in entities:
# TODO: current naive implementation uses exact string matching which can seem to be a baseline
# we can improve this further by handling ambiguity, for example:
# - requesting surrounding context from LLM
# - perform fuzzy matching over strings if model still hallucinates when copying the text
ent_str = entity[self._quote_string_field_name]
# to avoid overlapping entities, start from the end of the last entity with the same prefix
matching_end_indices = [
found_entities_ends[found_ent]
for found_ent in found_entities_ends
if found_ent.startswith(ent_str)
]
if matching_end_indices:
# start searching from the end of the last entity with the same prefix
start_search_idx = max(matching_end_indices)
else:
# start searching from the beginning
start_search_idx = 0

start_idx = text.lower().find(
entity[self._quote_string_field_name].lower(),
start_search_idx,
)
if start_idx == -1:
# we need to remove the entity if it is not found in the text
to_remove.append(entity)
else:
end_index = start_idx + len(entity[self._quote_string_field_name])
entity["start"] = start_idx
entity["end"] = end_index
found_entities_ends[ent_str] = end_index
for entity in to_remove:
entities.remove(entity)
df = extract_indices(df, input_field_name, output_field_name, self._quote_string_field_name, self._labels_field_name)
return df

def apply(
Expand Down
23 changes: 20 additions & 3 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Dict, Any, Type
import pandas as pd
from typing import Optional, Type
from functools import cached_property
from adala.skills._base import TransformSkill
from pydantic import BaseModel, Field, model_validator
Expand All @@ -8,8 +9,10 @@
from adala.utils.internal_data import InternalDataFrame

from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic

from .entity_extraction import extract_indices

logger = logging.getLogger(__name__)

Expand All @@ -24,7 +27,14 @@ class LabelStudioSkill(TransformSkill):
# ------------------------------
label_config: str = "<View></View>"

# TODO: implement postprocessing like in EntityExtractionSkill or to verify Taxonomy
# TODO: implement postprocessing to verify Taxonomy

def has_ner_tag(self) -> Optional[ControlTag]:
    """
    Return the first <Labels> control tag from the label config, if any.

    Used to detect an NER setup (<Labels> paired with <Text>); callers read
    the returned tag's `name` and `objects` to wire up index extraction.
    Returns None when the config contains no <Labels> tag.
    """
    parsed_interface = LabelInterface(self.label_config)
    for control in parsed_interface.controls:
        if control.tag == 'Labels':
            return control
    return None

@model_validator(mode='after')
def validate_response_model(self):
Expand Down Expand Up @@ -62,10 +72,17 @@ async def aapply(
) -> InternalDataFrame:

with json_schema_to_pydantic(self.field_schema) as ResponseModel:
return await runtime.batch_to_batch(
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
)
ner_tag = self.has_ner_tag()
if ner_tag:
input_field_name = ner_tag.objects[0].value.lstrip('$')
output_field_name = ner_tag.name
quote_string_field_name = 'text'
output = extract_indices(pd.concat([input, output], axis=1), input_field_name, output_field_name, quote_string_field_name)
return output
9 changes: 7 additions & 2 deletions adala/skills/collection/prompt_improvement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from adala.skills import AnalysisSkill
from adala.utils.parse import parse_template
from adala.utils.types import ErrorResponseModel
from adala.skills.collection.label_studio import LabelStudioSkill

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,8 +76,12 @@ class PromptImprovementSkill(AnalysisSkill):

@model_validator(mode="after")
def validate_prompts(self):

input_variables = "\n".join(self.input_variables)

if isinstance(self.skill_to_improve, LabelStudioSkill):
model_json_schema = self.skill_to_improve.field_schema
else:
model_json_schema = self.skill_to_improve.response_model.model_json_schema()
# rewrite the instructions with the actual values
self.instructions = f"""\
You are a prompt engineer tasked with generating or enhancing a prompt for a Language Learning Model (LLM). Your goal is to create an effective prompt based on the given context, input data and requirements.
Expand All @@ -96,7 +101,7 @@ def validate_prompts(self):
## Target response schema
```json
{json.dumps(self.skill_to_improve.response_model.model_json_schema(), indent=2)}
{json.dumps(model_json_schema, indent=2)}
```
Now, examine the current prompt (if provided):
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
kombu = ">=5.4.0rc2" # Pin version to fix https://github.com/celery/celery/issues/8030. TODO: remove when this fix will be included in celery
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/e94072130d90f7b89701b9234ce175425f36f23e.zip"}
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/238b0f48dcac00d78a6e97862ea82011527307ce.zip"}
kafka-python = "^2.0.2"
# https://github.com/geerlingguy/ansible-role-docker/issues/462#issuecomment-2144121102
requests = "2.31.0"
Expand Down
Loading

0 comments on commit 43bfb50

Please sign in to comment.