From 41a3aa598a5e05e6cb2f969807057a3da73d259d Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 16:28:14 +0100 Subject: [PATCH 1/2] Fix REL label issue. --- spacy_llm/tasks/rel/task.py | 4 +++- spacy_llm/tests/tasks/test_rel.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index fe9f6e4e..87f90622 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -71,7 +71,9 @@ def _preannotate(doc: Union[Doc, RELExample]) -> str: for i, ent in enumerate(doc.ents): end = ent.end_char before, after = text[: end + offset], text[end + offset :] - annotation = f"[ENT{i}:{ent.label}]" + annotation = ( + f"[ENT{i}:{ent.label if isinstance(doc, RELExample) else ent.label_}]" + ) offset += len(annotation) text = f"{before}{annotation}{after}" diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index eb685f2e..ae8ac196 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -266,3 +266,23 @@ def test_incorrect_indexing(): ) == 0 ) + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.issue(366) +def test_labels(request: FixtureRequest): + config = Config().from_str(request.getfixturevalue("zeroshot_cfg_string")) + config["components"].pop("ner") + config.pop("initialize") + config["nlp"]["pipeline"] = ["llm"] + config["components"]["llm"]["task"]["labels"] = ["A", "B", "C"] + nlp = assemble_from_config(config) + + doc = Doc(get_lang_class("en")().vocab, words=["Well", "hello", "there"]) + doc.ents = [Span(doc, 0, 1, "A"), Span(doc, 1, 2, "B"), Span(doc, 2, 3, "C")] + + assert ( + "Well[ENT0:A] hello[ENT1:B] there[ENT2:C]" + in list(nlp.get_pipe("llm")._task.generate_prompts([doc]))[0] + ) From 7230b695f60e8b8e29e85266e27f7e254421ad9f Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Nov 2023 16:34:02 +0100 Subject: [PATCH 2/2] Remove .issue mark. --- spacy_llm/tests/tasks/test_rel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index ae8ac196..e4c110ae 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -270,8 +270,8 @@ def test_incorrect_indexing(): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.issue(366) -def test_labels(request: FixtureRequest): +def test_labels_in_prompt(request: FixtureRequest): + """See https://github.com/explosion/spacy-llm/issues/366.""" config = Config().from_str(request.getfixturevalue("zeroshot_cfg_string")) config["components"].pop("ner") config.pop("initialize")