Skip to content

Commit

Permalink
fix: DIA-1523: test coverage for LabelStudioSkill (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-bernstein authored Nov 7, 2024
1 parent 882ca68 commit 591c0ad
Show file tree
Hide file tree
Showing 4 changed files with 66,744 additions and 32 deletions.
27 changes: 27 additions & 0 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,33 @@
logger = logging.getLogger(__name__)


def validate_output_format_for_ner_tag(
    df: "InternalDataFrame", input_field_name: str, output_field_name: str
):
    """
    Fill in missing entity "text" values from the predicted character offsets.

    The output format for Labels is:
    {
        "start": start_idx,
        "end": end_idx,
        "text": text,
        "labels": [label1, label2, ...]
    }
    Sometimes the model cannot populate "text" correctly, but this can be fixed
    deterministically by slicing the input text with "start" and "end".

    Rows flagged with a truthy "_adala_error" are logged and skipped. Rows whose
    input text is not a string, or whose entities value is missing (None / NaN),
    are skipped as well instead of raising.

    Args:
        df: frame holding both the input text column and the predicted
            entities column.
        input_field_name: name of the column with the source text.
        output_field_name: name of the column with the predicted entity dicts.

    Returns:
        The same ``df`` object that was passed in; entity dicts are updated
        in place.
    """
    for i, row in df.iterrows():
        if row.get("_adala_error"):
            # .get() so a missing "_adala_message" column cannot raise KeyError
            # while we are only trying to report an upstream failure.
            logger.warning("Error in row %s: %s", i, row.get("_adala_message"))
            continue

        text = row[input_field_name]
        entities = row[output_field_name]
        # A missing prediction shows up as None or float NaN; neither can be
        # iterated, and a non-string input cannot be sliced.
        if (
            not isinstance(text, str)
            or entities is None
            or isinstance(entities, float)
        ):
            continue

        for entity in entities:
            corrected_text = text[entity["start"]:entity["end"]]
            if entity.get("text") is None:
                entity["text"] = corrected_text
            elif entity["text"] != corrected_text:
                # this seems to happen rarely if at all in testing, but could
                # lead to invalid predictions; only warn, do not overwrite.
                logger.warning(
                    "text and indices disagree for a predicted entity "
                    "(row %s: text=%r, indexed span=%r)",
                    i,
                    entity["text"],
                    corrected_text,
                )
    return df


def extract_indices(
df,
input_field_name,
Expand Down
18 changes: 10 additions & 8 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import pandas as pd
from typing import Optional, Type
from typing import Type, Iterator
from functools import cached_property
from adala.skills._base import TransformSkill
from pydantic import BaseModel, Field, model_validator
Expand All @@ -12,7 +12,7 @@
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic

from .entity_extraction import extract_indices
from .entity_extraction import extract_indices, validate_output_format_for_ner_tag

logger = logging.getLogger(__name__)

Expand All @@ -29,13 +29,14 @@ class LabelStudioSkill(TransformSkill):

# TODO: implement postprocessing to verify Taxonomy

def has_ner_tag(self) -> Optional[ControlTag]:
def ner_tags(self) -> Iterator[ControlTag]:
# Build a LabelInterface from `self.label_config` and go through its control tags,
# selecting those whose tag type is 'Labels' (used for NER: <Labels> applied to <Text>).
# NOTE(review): this is a rendered diff, so both `return tag` and `yield tag` appear below;
# presumably `return tag` is the removed line and `yield tag` its replacement — confirm against the repository.
interface = LabelInterface(self.label_config)
for tag in interface.controls:
# TODO: the object tag is not checked here, on the assumption that unusable control tags have already been stripped out of the label config; confirm this - maybe move this logic to LSE
if tag.tag == 'Labels':
return tag

yield tag
@model_validator(mode='after')
def validate_response_model(self):

Expand Down Expand Up @@ -79,10 +80,11 @@ async def aapply(
instructions_template=self.instructions,
response_model=ResponseModel,
)
ner_tag = self.has_ner_tag()
if ner_tag:
for ner_tag in self.ner_tags():
input_field_name = ner_tag.objects[0].value.lstrip('$')
output_field_name = ner_tag.name
quote_string_field_name = 'text'
output = extract_indices(pd.concat([input, output], axis=1), input_field_name, output_field_name, quote_string_field_name)
df = pd.concat([input, output], axis=1)
output = validate_output_format_for_ner_tag(df, input_field_name, output_field_name)
output = extract_indices(output, input_field_name, output_field_name, quote_string_field_name)
return output
Loading

0 comments on commit 591c0ad

Please sign in to comment.