Skip to content

Commit

Permalink
feat: DIA-1685: [sdk] Create example predictions and annotations from…
Browse files Browse the repository at this point in the history
… a LabelConfig (#360)
  • Loading branch information
matt-bernstein authored Nov 26, 2024
1 parent 8093e3d commit 179cec9
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 16 deletions.
152 changes: 150 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ typing_extensions = ">= 4.0.0"
ujson = ">=5.8.0"
xmljson = "0.2.1"

jsf = "^0.11.2"
[tool.poetry.dev-dependencies]
mypy = "1.0.1"
pytest = "^7.4.0"
Expand Down
67 changes: 63 additions & 4 deletions src/label_studio_sdk/label_interface/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import defaultdict, OrderedDict
from lxml import etree
import xmljson
from jsf import JSF

from label_studio_sdk._legacy.exceptions import (
LSConfigParseException,
Expand Down Expand Up @@ -770,7 +771,7 @@ def validate_region(self, region) -> bool:
return False

# type of the region should match the tag name
if control.tag.lower() != region["type"]:
if control.tag.lower() != region["type"].lower():
return False

# make sure that in config it connects to the same tag as
Expand Down Expand Up @@ -839,9 +840,67 @@ def generate_sample_task(self, mode="upload", secure_mode=False):

return task

def generate_sample_annotation(self):
""" """
raise NotImplemented()
def _generate_sample_regions(self):
""" Generate an example of each control tag's JSON schema and validate it as a region"""
return self.create_regions({
control.name: JSF(control.to_json_schema()).generate()
for control in self.controls
})

def generate_sample_prediction(self) -> Optional[dict]:
"""Generates a sample prediction that is valid for this label config.
Example:
{'model_version': 'sample model version',
'score': 0.0,
'result': [{'id': 'e7bd76e6-4e88-4eb3-b433-55e03661bf5d',
'from_name': 'sentiment',
'to_name': 'text',
'type': 'choices',
'value': {'choices': ['Neutral']}}]}
NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
"""
prediction = PredictionValue(
model_version='sample model version',
result=self._generate_sample_regions()
)
prediction_dct = prediction.model_dump()
if self.validate_prediction(prediction_dct):
return prediction_dct
else:
logger.debug(f'Sample prediction {prediction_dct} failed validation for label config {self.config}')
return None

def generate_sample_annotation(self) -> Optional[dict]:
"""Generates a sample annotation that is valid for this label config.
Example:
{'was_cancelled': False,
'ground_truth': False,
'lead_time': 0.0,
'result_count': 0,
'completed_by': -1,
'result': [{'id': 'b05da11d-3ffc-4657-8b8d-f5bc37cd59ac',
'from_name': 'sentiment',
'to_name': 'text',
'type': 'choices',
'value': {'choices': ['Negative']}}]}
NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
"""
annotation = AnnotationValue(
completed_by=-1, # annotator's user id
result=self._generate_sample_regions()
)
annotation_dct = annotation.model_dump()
if self.validate_annotation(annotation_dct):
return annotation_dct
else:
logger.debug(f'Sample annotation {annotation_dct} failed validation for label config {self.config}')
return None

#####
##### COMPATIBILITY LAYER
Expand Down
11 changes: 1 addition & 10 deletions src/label_studio_sdk/label_interface/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@

class Region(BaseModel):
"""
Class for Region Tag
Attributes:
-----------
id: str
The unique identifier of the region
x: int
The x coordinate of the region
y: int
A Region is an item in the `result` list of a PredictionValue or AnnotationValue.
"""

id: str = Field(default_factory=lambda: str(uuid4()))
Expand Down

0 comments on commit 179cec9

Please sign in to comment.