diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py
index fe9f6e4e..87f90622 100644
--- a/spacy_llm/tasks/rel/task.py
+++ b/spacy_llm/tasks/rel/task.py
@@ -71,7 +71,9 @@ def _preannotate(doc: Union[Doc, RELExample]) -> str:
     for i, ent in enumerate(doc.ents):
         end = ent.end_char
         before, after = text[: end + offset], text[end + offset :]
-        annotation = f"[ENT{i}:{ent.label}]"
+        annotation = (
+            f"[ENT{i}:{ent.label if isinstance(doc, RELExample) else ent.label_}]"
+        )
         offset += len(annotation)
         text = f"{before}{annotation}{after}"
 
diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py
index eb685f2e..e4c110ae 100644
--- a/spacy_llm/tests/tasks/test_rel.py
+++ b/spacy_llm/tests/tasks/test_rel.py
@@ -266,3 +266,23 @@ def test_incorrect_indexing():
         )
         == 0
     )
+
+
+@pytest.mark.external
+@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available")
+def test_labels_in_prompt(request: FixtureRequest):
+    """See https://github.com/explosion/spacy-llm/issues/366."""
+    config = Config().from_str(request.getfixturevalue("zeroshot_cfg_string"))
+    config["components"].pop("ner")
+    config.pop("initialize")
+    config["nlp"]["pipeline"] = ["llm"]
+    config["components"]["llm"]["task"]["labels"] = ["A", "B", "C"]
+    nlp = assemble_from_config(config)
+
+    doc = Doc(get_lang_class("en")().vocab, words=["Well", "hello", "there"])
+    doc.ents = [Span(doc, 0, 1, "A"), Span(doc, 1, 2, "B"), Span(doc, 2, 3, "C")]
+
+    assert (
+        "Well[ENT0:A] hello[ENT1:B] there[ENT2:C]"
+        in list(nlp.get_pipe("llm")._task.generate_prompts([doc]))[0]
+    )