Skip to content

Commit

Permalink
Add label definition to add function for LabeledTask (#340)
Browse files Browse the repository at this point in the history
* Additional clear functionality
optional add label_definition

* Small fix

* fix typos

* Adjusted test case label to catch
non working definition

* Fix tests

* reformat files

---------

Co-authored-by: Habib H <[email protected]>
  • Loading branch information
habibhaidari1 and Habib H authored Oct 30, 2023
1 parent 7a0460c commit e8cb182
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
9 changes: 7 additions & 2 deletions spacy_llm/pipeline/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ def labels(self) -> Tuple[str, ...]:
labels = self._task.labels
return labels

def add_label(self, label: str) -> int:
def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
if not isinstance(self._task, LabeledTask):
raise ValueError("The task of this LLM component does not have labels.")
return self._task.add_label(label)
return self._task.add_label(label, label_definition)

def clear(self) -> None:
if not isinstance(self._task, LabeledTask):
raise ValueError("The task of this LLM component does not have labels.")
return self._task.clear()

@property
def task(self) -> LLMTask:
Expand Down
13 changes: 12 additions & 1 deletion spacy_llm/tasks/builtin_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,25 @@ def _extract_labels_from_example(self, example: Example) -> List[str]:
def labels(self) -> Tuple[str, ...]:
return tuple(self._label_dict.values())

def add_label(self, label: str) -> int:
def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
"""Add a label to the task"""
if not isinstance(label, str):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
self._label_dict[self._normalizer(label)] = label
if label_definition is None:
return 1
if self._label_definitions is None:
self._label_definitions = {}
self._label_definitions[label] = label_definition
return 1

def clear(self) -> None:
"""Reset all labels."""
self._label_dict = {}
self._label_definitions = None

@property
def normalizer(self) -> Callable[[str], str]:
return self._normalizer
Expand Down
34 changes: 34 additions & 0 deletions spacy_llm/tests/tasks/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,41 @@ def test_add_label():
doc = nlp(text)
assert len(doc.ents) == 0

for label, definition in [
("PERSON", "Every person with the name Jack"),
("LOCATION", None),
]:
llm.add_label(label, definition)
doc = nlp(text)
assert len(doc.ents) == 2


@pytest.mark.external
@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available")
def test_clear_label():
nlp = spacy.blank("en")
llm = nlp.add_pipe(
"llm",
config={
"task": {
"@llm_tasks": "spacy.NER.v3",
},
"model": {
"@llm_models": "spacy.GPT-3-5.v1",
},
},
)

nlp.initialize()
text = "Jack and Jill visited France."
doc = nlp(text)

for label in ["PERSON", "LOCATION"]:
llm.add_label(label)
doc = nlp(text)
assert len(doc.ents) == 3

llm.clear()

doc = nlp(text)
assert len(doc.ents) == 0
5 changes: 4 additions & 1 deletion spacy_llm/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ class LabeledTask(Protocol):
def labels(self) -> Tuple[str, ...]:
...

def add_label(self, label: str) -> int:
def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
...

def clear(self) -> None:
...


Expand Down

0 comments on commit e8cb182

Please sign in to comment.