diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 1c3365d4..70003a8b 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -138,10 +138,15 @@ def labels(self) -> Tuple[str, ...]: labels = self._task.labels return labels - def add_label(self, label: str) -> int: + def add_label(self, label: str, label_definition: Optional[str] = None) -> int: if not isinstance(self._task, LabeledTask): raise ValueError("The task of this LLM component does not have labels.") - return self._task.add_label(label) + return self._task.add_label(label, label_definition) + + def clear(self) -> None: + if not isinstance(self._task, LabeledTask): + raise ValueError("The task of this LLM component does not have labels.") + return self._task.clear() @property def task(self) -> LLMTask: diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index fa565a97..3c4ef027 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -312,14 +312,25 @@ def _extract_labels_from_example(self, example: Example) -> List[str]: def labels(self) -> Tuple[str, ...]: return tuple(self._label_dict.values()) - def add_label(self, label: str) -> int: + def add_label(self, label: str, label_definition: Optional[str] = None) -> int: + """Add a label to the task""" if not isinstance(label, str): raise ValueError(Errors.E187) if label in self.labels: return 0 self._label_dict[self._normalizer(label)] = label + if label_definition is None: + return 1 + if self._label_definitions is None: + self._label_definitions = {} + self._label_definitions[label] = label_definition return 1 + def clear(self) -> None: + """Reset all labels.""" + self._label_dict = {} + self._label_definitions = None + @property def normalizer(self) -> Callable[[str], str]: return self._normalizer diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index 7d7c8c82..97512ca6 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -992,7 +992,41 @@ def test_add_label(): doc = nlp(text) assert len(doc.ents) == 0 + for label, definition in [ + ("PERSON", "Every person with the name Jack"), + ("LOCATION", None), + ]: + llm.add_label(label, definition) + doc = nlp(text) + assert len(doc.ents) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_clear_label(): + nlp = spacy.blank("en") + llm = nlp.add_pipe( + "llm", + config={ + "task": { + "@llm_tasks": "spacy.NER.v3", + }, + "model": { + "@llm_models": "spacy.GPT-3-5.v1", + }, + }, + ) + + nlp.initialize() + text = "Jack and Jill visited France." + doc = nlp(text) + for label in ["PERSON", "LOCATION"]: llm.add_label(label) doc = nlp(text) assert len(doc.ents) == 3 + + llm.clear() + + doc = nlp(text) + assert len(doc.ents) == 0 diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 0673b1d0..c07c0acb 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -136,7 +136,10 @@ class LabeledTask(Protocol): def labels(self) -> Tuple[str, ...]: ... - def add_label(self, label: str) -> int: + def add_label(self, label: str, label_definition: Optional[str] = None) -> int: + ... + + def clear(self) -> None: ...